From 7eca43f29199e035416c529b8db00ed00e4d4080 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Fri, 27 Oct 2023 13:11:54 +0200 Subject: [PATCH 1/2] SNOW-856228 enable easy logging (#946) SNOW-856228 enable easy logging --- client_configuration.go | 6 +- client_configuration_test.go | 53 ++++------ connection.go | 4 + dsn.go | 7 ++ dsn_test.go | 51 +++++++++ easy_logging.go | 153 +++++++++++++++++++++++++++ easy_logging_test.go | 190 +++++++++++++++++++++++++++++++++ errors.go | 3 + log.go | 199 +++++++++++++++++++++++++++-------- log_test.go | 34 ++++++ 10 files changed, 627 insertions(+), 73 deletions(-) create mode 100644 easy_logging.go create mode 100644 easy_logging_test.go diff --git a/client_configuration.go b/client_configuration.go index 52d55de73..c3b3573d9 100644 --- a/client_configuration.go +++ b/client_configuration.go @@ -30,7 +30,7 @@ func getClientConfig(filePathFromConnectionString string) (*ClientConfig, error) configPredefinedFilePaths := clientConfigPredefinedDirs() filePath, err := findClientConfigFilePath(filePathFromConnectionString, configPredefinedFilePaths) if err != nil { - return nil, err + return nil, findClientConfigError(err) } if filePath == "" { // we did not find a config file return nil, nil @@ -63,6 +63,10 @@ func searchForConfigFile(directories []string) (string, error) { return "", nil } +func findClientConfigError(err error) error { + return fmt.Errorf("finding client config failed: %w", err) +} + func existsFile(filePath string) (bool, error) { _, err := os.Stat(filePath) if err == nil { diff --git a/client_configuration_test.go b/client_configuration_test.go index a63eb6753..14292cc8a 100644 --- a/client_configuration_test.go +++ b/client_configuration_test.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path" + "strings" "testing" ) @@ -84,12 +85,7 @@ func TestCreatePredefinedDirs(t *testing.T) { func TestGetClientConfig(t *testing.T) { dir := t.TempDir() fileName := "config.json" - configContents := `{ - "common": { - "log_level" : "INFO", - "log_path" : "/some-path/some-directory" - } - }` + configContents := createClientConfigContent("INFO", "/some-path/some-directory") createFile(t, fileName, configContents, dir) filePath := path.Join(dir, fileName) @@ -118,26 +114,16 @@ func TestParseConfiguration(t *testing.T) { expectedLogPath string }{ { - testName: "TestWithLogLevelUpperCase", - fileName: "config_1.json", - fileContents: `{ - "common": { - "log_level" : "INFO", - "log_path" : "/some-path/some-directory" - } - }`, + testName: "TestWithLogLevelUpperCase", + fileName: "config_1.json", + fileContents: createClientConfigContent("INFO", "/some-path/some-directory"), expectedLogLevel: "INFO", expectedLogPath: "/some-path/some-directory", }, { - testName: "TestWithLogLevelLowerCase", - fileName: "config_2.json", - fileContents: `{ - "common": { - "log_level" : "info", - "log_path" : "/some-path/some-directory" - } - }`, + testName: "TestWithLogLevelLowerCase", + fileName: "config_2.json", + fileContents: createClientConfigContent("info", "/some-path/some-directory"), expectedLogLevel: "info", expectedLogPath: "/some-path/some-directory", }, @@ -193,14 +179,9 @@ func TestParseConfigurationFails(t *testing.T) { expectedErrorMessageToContain string }{ { - testName: "TestWithWrongLogLevel", - fileName: "config_1.json", - FileContents: `{ - "common": { - "log_level" : "something weird", - "log_path" : "/some-path/some-directory" - } - }`, + testName: "TestWithWrongLogLevel", + fileName: "config_1.json", + FileContents: createClientConfigContent("something weird", "/some-path/some-directory"), expectedErrorMessageToContain: "unknown log level", }, { @@ -284,3 +265,15 @@ func predefinedTestDirs(dirs struct { }) []string { return []string{dirs.predefinedDir1, dirs.predefinedDir2} } + +func createClientConfigContent(logLevel string, logPath string) string { + return fmt.Sprintf(`{ + "common": { + "log_level" : "%s", + "log_path" : "%s" + } + }`, + logLevel, + strings.ReplaceAll(logPath, "\\", "\\\\"), + ) +} diff --git a/connection.go b/connection.go index 6fcc2c437..087ac7215 100644 --- a/connection.go +++ b/connection.go @@ -737,6 +737,10 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err queryContextCache: (&queryContextCache{}).init(), currentTimeProvider: defaultTimeProvider, } + err := initEasyLogging(config.ClientConfigFile) + if err != nil { + return nil, err + } var st http.RoundTripper = SnowflakeTransport if sc.cfg.Transporter == nil { if sc.cfg.InsecureMode { diff --git a/dsn.go b/dsn.go index c365e1506..b3a8212f7 100644 --- a/dsn.go +++ b/dsn.go @@ -101,6 +101,8 @@ type Config struct { DisableQueryContextCache bool // Should HTAP query context cache be disabled IncludeRetryReason ConfigBool // Should retried request contain retry reason + + ClientConfigFile string // File path to the client configuration json file } // Validate enables testing if config is correct. @@ -252,6 +254,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.ClientStoreTemporaryCredential != configBoolNotSet { params.Add("clientStoreTemporaryCredential", strconv.FormatBool(cfg.ClientStoreTemporaryCredential != ConfigBoolFalse)) } + if cfg.ClientConfigFile != "" { + params.Add("clientConfigFile", cfg.ClientConfigFile) + } dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port) if params.Encode() != "" { @@ -734,6 +739,8 @@ func parseDSNParams(cfg *Config, params string) (err error) { } else { cfg.IncludeRetryReason = ConfigBoolFalse } + case "clientConfigFile": + cfg.ClientConfigFile = value default: if cfg.Params == nil { cfg.Params = make(map[string]*string) diff --git a/dsn_test.go b/dsn_test.go index e3e30d695..e49207bd5 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -659,6 +659,36 @@ func TestParseDSN(t *testing.T) { ocspMode: ocspModeFailOpen, err: nil, }, + { + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&includeRetryReason=true&clientConfigFile=%2FUsers%2Fuser%2Fconfig.json", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, + Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, + ClientConfigFile: "/Users/user/config.json", + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, + { + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&includeRetryReason=true&clientConfigFile=c%3A%5CUsers%5Cuser%5Cconfig.json", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, + Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, + ClientConfigFile: "c:\\Users\\user\\config.json", + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, } for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { @@ -818,6 +848,7 @@ func TestParseDSN(t *testing.T) { if test.config.IncludeRetryReason != cfg.IncludeRetryReason { t.Fatalf("%v: Failed to match IncludeRetryReason. expected: %v, got: %v", i, test.config.IncludeRetryReason, cfg.IncludeRetryReason) } + assertEqualF(t, cfg.ClientConfigFile, test.config.ClientConfigFile, "client config file") case test.err != nil: driverErrE, okE := test.err.(*SnowflakeError) driverErrG, okG := err.(*SnowflakeError) @@ -1236,6 +1267,26 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + IncludeRetryReason: ConfigBoolTrue, + ClientConfigFile: "/Users/user/config.json", + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientConfigFile=%2FUsers%2Fuser%2Fconfig.json&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + IncludeRetryReason: ConfigBoolTrue, + ClientConfigFile: "c:\\Users\\user\\config.json", + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientConfigFile=c%3A%5CUsers%5Cuser%5Cconfig.json&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, } for _, test := range testcases { t.Run(test.dsn, func(t *testing.T) { diff --git a/easy_logging.go b/easy_logging.go new file mode 100644 index 000000000..e61b1b244 --- /dev/null +++ b/easy_logging.go @@ -0,0 +1,153 @@ +package gosnowflake + +import ( + "errors" + "io" + "os" + "path" + "strings" +) + +type initTrials struct { + everTriedToInitialize bool + clientConfigFileInput string + configureCounter int +} + +var easyLoggingInitTrials = initTrials{ + everTriedToInitialize: false, + clientConfigFileInput: "", + configureCounter: 0, +} + +func (i *initTrials) setInitTrial(clientConfigFileInput string) { + i.everTriedToInitialize = true + i.clientConfigFileInput = clientConfigFileInput +} + +func (i *initTrials) increaseReconfigureCounter() { + i.configureCounter++ +} + +func (i *initTrials) reset() { + i.everTriedToInitialize = false + i.clientConfigFileInput = "" + i.configureCounter = 0 +} + +func initEasyLogging(clientConfigFileInput string) error { + if !allowedToInitialize(clientConfigFileInput) { + return nil + } + config, err := getClientConfig(clientConfigFileInput) + if err != nil { + return easyLoggingInitError(err) + } + if config == nil { + easyLoggingInitTrials.setInitTrial(clientConfigFileInput) + return nil + } + var logLevel string + logLevel, err = getLogLevel(config.Common.LogLevel) + if err != nil { + return easyLoggingInitError(err) + } + var logPath string + logPath, err = getLogPath(config.Common.LogPath) + if err != nil { + return easyLoggingInitError(err) + } + err = reconfigureEasyLogging(logLevel, logPath) + easyLoggingInitTrials.setInitTrial(clientConfigFileInput) + easyLoggingInitTrials.increaseReconfigureCounter() + return err +} + +func easyLoggingInitError(err error) error { + return &SnowflakeError{ + Number: ErrCodeClientConfigFailed, + Message: errMsgClientConfigFailed, + MessageArgs: []interface{}{err.Error()}, + } +} + +func reconfigureEasyLogging(logLevel string, logPath string) error { + newLogger := CreateDefaultLogger() + err := newLogger.SetLogLevel(logLevel) + if err != nil { + return err + } + var output io.Writer + var file *os.File + output, file, err = createLogWriter(logPath) + if err != nil { + return err + } + newLogger.SetOutput(output) + err = newLogger.CloseFileOnLoggerReplace(file) + if err != nil { + logger.Errorf("%s", err) + } + logger.Replace(&newLogger) + return nil +} + +func createLogWriter(logPath string) (io.Writer, *os.File, error) { + if strings.EqualFold(logPath, "STDOUT") { + return os.Stdout, nil, nil + } + logFileName := path.Join(logPath, "snowflake.log") + file, err := os.OpenFile(logFileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0640) + if err != nil { + return nil, nil, err + } + return io.MultiWriter(file, os.Stdout), file, nil +} + +func allowedToInitialize(clientConfigFileInput string) bool { + triedToInitializeWithoutConfigFile := easyLoggingInitTrials.everTriedToInitialize && easyLoggingInitTrials.clientConfigFileInput == "" + isAllowedToInitialize := !easyLoggingInitTrials.everTriedToInitialize || (triedToInitializeWithoutConfigFile && clientConfigFileInput != "") + if !isAllowedToInitialize && easyLoggingInitTrials.clientConfigFileInput != clientConfigFileInput { + logger.Warnf("Easy logging will not be configured for CLIENT_CONFIG_FILE=%s because it was previously configured for a different client config", clientConfigFileInput) + } + return isAllowedToInitialize +} + +func getLogLevel(logLevel string) (string, error) { + if logLevel == "" { + logger.Warn("LogLevel in client config not found. Using default value: OFF") + return levelOff, nil + } + return toLogLevel(logLevel) +} + +func getLogPath(logPath string) (string, error) { + logPathOrDefault := logPath + if logPath == "" { + logPathOrDefault = os.TempDir() + logger.Warnf("LogPath in client config not found. Using temporary directory as a default value: %s", logPathOrDefault) + } + pathWithGoSubdir := path.Join(logPathOrDefault, "go") + exists, err := dirExists(pathWithGoSubdir) + if err != nil { + return "", err + } + if !exists { + err = os.MkdirAll(pathWithGoSubdir, 0755) + if err != nil { + return "", err + } + } + return pathWithGoSubdir, nil +} + +func dirExists(dirPath string) (bool, error) { + stat, err := os.Stat(dirPath) + if err == nil { + return stat.IsDir(), nil + } + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + return false, err +} diff --git a/easy_logging_test.go b/easy_logging_test.go new file mode 100644 index 000000000..f9245807b --- /dev/null +++ b/easy_logging_test.go @@ -0,0 +1,190 @@ +package gosnowflake + +import ( + "context" + "fmt" + "os" + "path" + "strings" + "testing" +) + +func TestInitializeEasyLoggingOnlyOnceWhenConfigGivenAsAParameter(t *testing.T) { + defer cleanUp() + dir := t.TempDir() + logLevel := levelError + contents := createClientConfigContent(logLevel, dir) + configFilePath := createFile(t, "config.json", contents, dir) + easyLoggingInitTrials.reset() + + err := openWithClientConfigFile(t, configFilePath) + + assertNilF(t, err, "open config error") + assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), logLevel, "error log level check") + assertEqualE(t, easyLoggingInitTrials.configureCounter, 1) + + err = openWithClientConfigFile(t, "") + assertNilF(t, err, "open config error") + err = openWithClientConfigFile(t, configFilePath) + assertNilF(t, err, "open config error") + err = openWithClientConfigFile(t, "/another-config.json") + assertNilF(t, err, "open config error") + + assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), logLevel, "error log level check") + assertEqualE(t, easyLoggingInitTrials.configureCounter, 1) +} + +func TestConfigureEasyLoggingOnlyOnceWhenInitializedWithoutConfigFilePath(t *testing.T) { + defer cleanUp() + dir := t.TempDir() + logLevel := levelError + contents := createClientConfigContent(logLevel, dir) + configFilePath := createFile(t, defaultConfigName, contents, os.TempDir()) + defer os.Remove(configFilePath) + easyLoggingInitTrials.reset() + + err := openWithClientConfigFile(t, "") + assertNilF(t, err, "open config error") + err = openWithClientConfigFile(t, "") + assertNilF(t, err, "open config error") + + assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), logLevel, "error log level check") + assertEqualE(t, easyLoggingInitTrials.configureCounter, 1) +} + +func TestReconfigureEasyLoggingIfConfigPathWasNotGivenForTheFirstTime(t *testing.T) { + defer cleanUp() + dir := t.TempDir() + tmpDirLogLevel := levelError + tmpFileContent := createClientConfigContent(tmpDirLogLevel, dir) + tmpDirConfigFilePath := createFile(t, defaultConfigName, tmpFileContent, os.TempDir()) + defer os.Remove(tmpDirConfigFilePath) + customLogLevel := levelWarn + customFileContent := createClientConfigContent(customLogLevel, dir) + customConfigFilePath := createFile(t, "config.json", customFileContent, dir) + easyLoggingInitTrials.reset() + + err := openWithClientConfigFile(t, "") + logger.Error("Error message") + + assertNilF(t, err, "open config error") + assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), tmpDirLogLevel, "tmp dir log level check") + assertEqualE(t, easyLoggingInitTrials.configureCounter, 1) + + err = openWithClientConfigFile(t, customConfigFilePath) + logger.Error("Warning message") + + assertNilF(t, err, "open config error") + assertEqualE(t, toClientConfigLevel(logger.GetLogLevel()), customLogLevel, "custom dir log level check") + assertEqualE(t, easyLoggingInitTrials.configureCounter, 2) + var logContents []byte + logContents, err = os.ReadFile(path.Join(dir, "go", "snowflake.log")) + assertNilF(t, err, "read file error") + logs := notEmptyLines(string(logContents)) + assertEqualE(t, len(logs), 2, "number of logs") +} + +func TestEasyLoggingFailOnUnknownLevel(t *testing.T) { + defer cleanUp() + dir := t.TempDir() + easyLoggingInitTrials.reset() + configContent := createClientConfigContent("something_unknown", dir) + configFilePath := createFile(t, "config.json", configContent, dir) + + err := openWithClientConfigFile(t, configFilePath) + + assertNotNilF(t, err, "open config error") + assertStringContainsE(t, err.Error(), fmt.Sprint(ErrCodeClientConfigFailed), "error code") + assertStringContainsE(t, err.Error(), "parsing client config failed", "error message") +} + +func TestEasyLoggingFailOnNotExistingConfigFile(t *testing.T) { + defer cleanUp() + easyLoggingInitTrials.reset() + + err := openWithClientConfigFile(t, "/not-existing-file.json") + + assertNotNilF(t, err, "open config error") + assertStringContainsE(t, err.Error(), fmt.Sprint(ErrCodeClientConfigFailed), "error code") + assertStringContainsE(t, err.Error(), "parsing client config failed", "error message") +} + +func TestLogToConfiguredFile(t *testing.T) { + defer cleanUp() + dir := t.TempDir() + easyLoggingInitTrials.reset() + configContent := createClientConfigContent(levelWarn, dir) + configFilePath := createFile(t, "config.json", configContent, dir) + logFilePath := path.Join(dir, "go", "snowflake.log") + err := openWithClientConfigFile(t, configFilePath) + assertNilF(t, err, "open config error") + + logger.Error("Error message") + logger.Warn("Warning message") + logger.Warning("Warning message") + logger.Info("Info message") + logger.Trace("Trace message") + + var logContents []byte + logContents, err = os.ReadFile(logFilePath) + assertNilF(t, err, "read file error") + logs := notEmptyLines(string(logContents)) + assertEqualE(t, len(logs), 3, "number of logs") + errorLogs := filterStrings(logs, func(val string) bool { + return strings.Contains(val, "level=error") + }) + assertEqualE(t, len(errorLogs), 1, "error logs count") + warningLogs := filterStrings(logs, func(val string) bool { + return strings.Contains(val, "level=warning") + }) + assertEqualE(t, len(warningLogs), 2, "warning logs count") +} + +func notEmptyLines(lines string) []string { + notEmptyFunc := func(val string) bool { + return val != "" + } + return filterStrings(strings.Split(strings.ReplaceAll(lines, "\r\n", "\n"), "\n"), notEmptyFunc) +} + +func cleanUp() { + newLogger := CreateDefaultLogger() + logger.Replace(&newLogger) + easyLoggingInitTrials.reset() +} + +func toClientConfigLevel(logLevel string) string { + logLevelUpperCase := strings.ToUpper(logLevel) + switch strings.ToUpper(logLevel) { + case "WARNING": + return levelWarn + case levelOff, levelError, levelWarn, levelInfo, levelDebug, levelTrace: + return logLevelUpperCase + default: + return "" + } +} + +func filterStrings(values []string, keep func(string) bool) []string { + filteredStrings := []string{} + for _, val := range values { + if keep(val) { + filteredStrings = append(filteredStrings, val) + } + } + return filteredStrings +} + +func defaultConfig(t *testing.T) *Config { + config, err := ParseDSN(dsn) + assertNilF(t, err, "parse dsn error") + return config +} + +func openWithClientConfigFile(t *testing.T, clientConfigFile string) error { + driver := SnowflakeDriver{} + config := defaultConfig(t) + config.ClientConfigFile = clientConfigFile + _, err := driver.OpenWithConfig(context.Background(), *config) + return err +} diff --git a/errors.go b/errors.go index 64a77ac5f..1f1ad17a8 100644 --- a/errors.go +++ b/errors.go @@ -125,6 +125,8 @@ const ( ErrCodePrivateKeyParseError = 260010 // ErrCodeFailedToParseAuthenticator is an error code for the case where a DNS includes an invalid authenticator ErrCodeFailedToParseAuthenticator = 260011 + // ErrCodeClientConfigFailed is an error code for the case where clientConfigFile is invalid or applying client configuration fails + ErrCodeClientConfigFailed = 260012 /* network */ @@ -290,6 +292,7 @@ const ( errMsgNoResultIDs = "no result IDs returned with the multi-statement query" errMsgQueryStatus = "server ErrorCode=%s, ErrorMessage=%s" errMsgInvalidPadding = "invalid padding on input" + errMsgClientConfigFailed = "client configuration failed: %v" ) // Returned if a DNS doesn't include account parameter. diff --git a/log.go b/log.go index 87fb69866..b48294cb6 100644 --- a/log.go +++ b/log.go @@ -7,8 +7,10 @@ import ( "fmt" rlog "github.com/sirupsen/logrus" "io" + "os" "path" "runtime" + "strings" "time" ) @@ -25,8 +27,11 @@ var LogKeys = [...]contextKey{SFSessionIDKey, SFSessionUserKey} type SFLogger interface { rlog.Ext1FieldLogger SetLogLevel(level string) error + GetLogLevel() string WithContext(ctx context.Context) *rlog.Entry SetOutput(output io.Writer) + CloseFileOnLoggerReplace(file *os.File) error + Replace(newLogger *SFLogger) } // SFCallerPrettyfier to provide base file name and function name from calling frame used in SFLogger @@ -35,19 +40,57 @@ func SFCallerPrettyfier(frame *runtime.Frame) (string, string) { } type defaultLogger struct { - inner *rlog.Logger + inner *rlog.Logger + enabled bool + file *os.File } // SetLogLevel set logging level for calling defaultLogger func (log *defaultLogger) SetLogLevel(level string) error { - actualLevel, err := rlog.ParseLevel(level) - if err != nil { - return err + newEnabled := strings.ToUpper(level) != "OFF" + log.enabled = newEnabled + if newEnabled { + actualLevel, err := rlog.ParseLevel(level) + if err != nil { + return err + } + log.inner.SetLevel(actualLevel) } - log.inner.SetLevel(actualLevel) return nil } +// GetLogLevel return current log level +func (log *defaultLogger) GetLogLevel() string { + if !log.enabled { + return "OFF" + } + return log.inner.GetLevel().String() +} + +// CloseFileOnLoggerReplace set a file to be closed when releasing resources occupied by the logger +func (log *defaultLogger) CloseFileOnLoggerReplace(file *os.File) error { + if log.file != nil && log.file != file { + return fmt.Errorf("could not set a file to close on logger reset because there were already set one") + } + log.file = file + return nil +} + +// Replace substitute logger by a given one +func (log *defaultLogger) Replace(newLogger *SFLogger) { + SetLogger(newLogger) + closeLogFile(log.file) +} + +func closeLogFile(file *os.File) { + if file != nil { + err := file.Close() + if err != nil { + logger.Errorf("failed to close log file: %s", err) + } + } +} + // WithContext return Entry to include fields in context func (log *defaultLogger) WithContext(ctx context.Context) *rlog.Entry { fields := context2Fields(ctx) @@ -60,7 +103,7 @@ func CreateDefaultLogger() SFLogger { var formatter = rlog.TextFormatter{CallerPrettyfier: SFCallerPrettyfier} rLogger.SetReportCaller(true) rLogger.SetFormatter(&formatter) - var ret = defaultLogger{inner: rLogger} + var ret = defaultLogger{inner: rLogger, enabled: true} return &ret //(&ret).(*SFLogger) } @@ -95,39 +138,57 @@ func (log *defaultLogger) Logf(level rlog.Level, format string, args ...interfac } func (log *defaultLogger) Tracef(format string, args ...interface{}) { - log.inner.Tracef(format, args...) + if log.enabled { + log.inner.Tracef(format, args...) + } } func (log *defaultLogger) Debugf(format string, args ...interface{}) { - log.inner.Debugf(format, args...) + if log.enabled { + log.inner.Debugf(format, args...) + } } func (log *defaultLogger) Infof(format string, args ...interface{}) { - log.inner.Infof(format, args...) + if log.enabled { + log.inner.Infof(format, args...) + } } func (log *defaultLogger) Printf(format string, args ...interface{}) { - log.inner.Printf(format, args...) + if log.enabled { + log.inner.Printf(format, args...) + } } func (log *defaultLogger) Warnf(format string, args ...interface{}) { - log.inner.Warnf(format, args...) + if log.enabled { + log.inner.Warnf(format, args...) + } } func (log *defaultLogger) Warningf(format string, args ...interface{}) { - log.inner.Warningf(format, args...) + if log.enabled { + log.inner.Warningf(format, args...) + } } func (log *defaultLogger) Errorf(format string, args ...interface{}) { - log.inner.Errorf(format, args...) + if log.enabled { + log.inner.Errorf(format, args...) + } } func (log *defaultLogger) Fatalf(format string, args ...interface{}) { - log.inner.Fatalf(format, args...) + if log.enabled { + log.inner.Fatalf(format, args...) + } } func (log *defaultLogger) Panicf(format string, args ...interface{}) { - log.inner.Panicf(format, args...) + if log.enabled { + log.inner.Panicf(format, args...) + } } func (log *defaultLogger) Log(level rlog.Level, args ...interface{}) { @@ -139,75 +200,111 @@ func (log *defaultLogger) LogFn(level rlog.Level, fn rlog.LogFunction) { } func (log *defaultLogger) Trace(args ...interface{}) { - log.inner.Trace(args...) + if log.enabled { + log.inner.Trace(args...) + } } func (log *defaultLogger) Debug(args ...interface{}) { - log.inner.Debug(args...) + if log.enabled { + log.inner.Debug(args...) + } } func (log *defaultLogger) Info(args ...interface{}) { - log.inner.Info(args...) + if log.enabled { + log.inner.Info(args...) + } } func (log *defaultLogger) Print(args ...interface{}) { - log.inner.Print(args...) + if log.enabled { + log.inner.Print(args...) + } } func (log *defaultLogger) Warn(args ...interface{}) { - log.inner.Warn(args...) + if log.enabled { + log.inner.Warn(args...) + } } func (log *defaultLogger) Warning(args ...interface{}) { - log.inner.Warning(args...) + if log.enabled { + log.inner.Warning(args...) + } } func (log *defaultLogger) Error(args ...interface{}) { - log.inner.Error(args...) + if log.enabled { + log.inner.Error(args...) + } } func (log *defaultLogger) Fatal(args ...interface{}) { - log.inner.Fatal(args...) + if log.enabled { + log.inner.Fatal(args...) + } } func (log *defaultLogger) Panic(args ...interface{}) { - log.inner.Panic(args...) + if log.enabled { + log.inner.Panic(args...) + } } func (log *defaultLogger) TraceFn(fn rlog.LogFunction) { - log.inner.TraceFn(fn) + if log.enabled { + log.inner.TraceFn(fn) + } } func (log *defaultLogger) DebugFn(fn rlog.LogFunction) { - log.inner.DebugFn(fn) + if log.enabled { + log.inner.DebugFn(fn) + } } func (log *defaultLogger) InfoFn(fn rlog.LogFunction) { - log.inner.InfoFn(fn) + if log.enabled { + log.inner.InfoFn(fn) + } } func (log *defaultLogger) PrintFn(fn rlog.LogFunction) { - log.inner.PrintFn(fn) + if log.enabled { + log.inner.PrintFn(fn) + } } func (log *defaultLogger) WarnFn(fn rlog.LogFunction) { - log.inner.PrintFn(fn) + if log.enabled { + log.inner.PrintFn(fn) + } } func (log *defaultLogger) WarningFn(fn rlog.LogFunction) { - log.inner.WarningFn(fn) + if log.enabled { + log.inner.WarningFn(fn) + } } func (log *defaultLogger) ErrorFn(fn rlog.LogFunction) { - log.inner.ErrorFn(fn) + if log.enabled { + log.inner.ErrorFn(fn) + } } func (log *defaultLogger) FatalFn(fn rlog.LogFunction) { - log.inner.FatalFn(fn) + if log.enabled { + log.inner.FatalFn(fn) + } } func (log *defaultLogger) PanicFn(fn rlog.LogFunction) { - log.inner.PanicFn(fn) + if log.enabled { + log.inner.PanicFn(fn) + } } func (log *defaultLogger) Logln(level rlog.Level, args ...interface{}) { @@ -215,39 +312,57 @@ func (log *defaultLogger) Logln(level rlog.Level, args ...interface{}) { } func (log *defaultLogger) Traceln(args ...interface{}) { - log.inner.Traceln(args...) + if log.enabled { + log.inner.Traceln(args...) + } } func (log *defaultLogger) Debugln(args ...interface{}) { - log.inner.Debugln(args...) + if log.enabled { + log.inner.Debugln(args...) + } } func (log *defaultLogger) Infoln(args ...interface{}) { - log.inner.Infoln(args...) + if log.enabled { + log.inner.Infoln(args...) + } } func (log *defaultLogger) Println(args ...interface{}) { - log.inner.Println(args...) + if log.enabled { + log.inner.Println(args...) + } } func (log *defaultLogger) Warnln(args ...interface{}) { - log.inner.Warnln(args...) + if log.enabled { + log.inner.Warnln(args...) + } } func (log *defaultLogger) Warningln(args ...interface{}) { - log.inner.Warningln(args...) + if log.enabled { + log.inner.Warningln(args...) + } } func (log *defaultLogger) Errorln(args ...interface{}) { - log.inner.Errorln(args...) + if log.enabled { + log.inner.Errorln(args...) + } } func (log *defaultLogger) Fatalln(args ...interface{}) { - log.inner.Fatalln(args...) + if log.enabled { + log.inner.Fatalln(args...) + } } func (log *defaultLogger) Panicln(args ...interface{}) { - log.inner.Panicln(args...) + if log.enabled { + log.inner.Panicln(args...) + } } func (log *defaultLogger) Exit(code int) { diff --git a/log_test.go b/log_test.go index b585c6411..e4ab43b88 100644 --- a/log_test.go +++ b/log_test.go @@ -106,6 +106,40 @@ func TestDefaultLogLevel(t *testing.T) { } } +func TestOffLogLevel(t *testing.T) { + logger := CreateDefaultLogger() + buf := &bytes.Buffer{} + logger.SetOutput(buf) + err := logger.SetLogLevel("OFF") + assertNilF(t, err) + SetLogger(&logger) + + logger.Info("info") + logger.Infof("info%v", "f") + logger.Infoln("infoln") + logger.Debug("debug") + logger.Debugf("debug%v", "f") + logger.Debugln("debugln") + logger.Trace("trace") + logger.Tracef("trace%v", "f") + logger.Traceln("traceln") + logger.Print("print") + logger.Printf("print%v", "f") + logger.Println("println") + logger.Warn("warn") + logger.Warnf("warn%v", "f") + logger.Warnln("warnln") + logger.Warning("warning") + logger.Warningf("warning%v", "f") + logger.Warningln("warningln") + logger.Error("error") + logger.Errorf("error%v", "f") + logger.Errorln("errorln") + + assertEqualE(t, buf.Len(), 0, "log messages count") + assertEqualE(t, logger.GetLogLevel(), "OFF", "log level") +} + func TestLogSetLevel(t *testing.T) { logger := GetLogger() buf := &bytes.Buffer{} From 84793092bd753a9306430ef223ddeeea61887b23 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Fri, 27 Oct 2023 15:12:30 +0200 Subject: [PATCH 2/2] SNOW-878073 Add MaxRetryCount as configurable parameter (#948) --- auth.go | 2 +- chunk_downloader.go | 4 +- connection.go | 1 + dsn.go | 17 +++++- dsn_test.go | 27 ++++++++++ ocsp.go | 6 +-- restful.go | 12 +++-- restful_test.go | 10 ++-- retry.go | 10 ++-- retry_test.go | 128 +++++++++++++++++++------------------------- 10 files changed, 121 insertions(+), 96 deletions(-) diff --git a/auth.go b/auth.go index c3894e43b..9493459f9 100644 --- a/auth.go +++ b/auth.go @@ -226,7 +226,7 @@ func postAuth( fullURL := sr.getFullURL(loginRequestPath, params) logger.Infof("full URL: %v", fullURL) - resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout) + resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, sr.MaxRetryCount) if err != nil { return nil, err } diff --git a/chunk_downloader.go b/chunk_downloader.go index 6806e5e32..626b5c445 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -264,7 +264,7 @@ func getChunk( if err != nil { return nil, err } - return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.currentTimeProvider, sc.cfg).execute() + return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.rest.MaxRetryCount, sc.currentTimeProvider, sc.cfg).execute() } func (scd *snowflakeChunkDownloader) startArrowBatches() error { @@ -638,7 +638,7 @@ func (f *httpStreamChunkFetcher) fetch(URL string, rows chan<- []*string) error if err != nil { return err } - res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, defaultTimeProvider, nil).execute() + res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, 0, defaultTimeProvider, nil).execute() if err != nil { return err } diff --git a/connection.go b/connection.go index 087ac7215..5b39d1460 100644 --- a/connection.go +++ b/connection.go @@ -789,6 +789,7 @@ func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, err TokenAccessor: tokenAccessor, LoginTimeout: sc.cfg.LoginTimeout, RequestTimeout: sc.cfg.RequestTimeout, + MaxRetryCount: sc.cfg.MaxRetryCount, FuncPost: postRestful, FuncGet: getRestful, FuncAuthPost: postAuthRestful, diff --git a/dsn.go b/dsn.go index b3a8212f7..1ccebba05 100644 --- a/dsn.go +++ b/dsn.go @@ -19,12 +19,13 @@ import ( ) const ( - defaultClientTimeout = 300 * time.Second // Timeout for network round trip + read out http response + defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response defaultJWTClientTimeout = 10 * time.Second // Timeout for network round trip + read out http response but used for JWT auth - defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout + defaultLoginTimeout = 300 * time.Second // Timeout for retry for login EXCLUDING clientTimeout defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout defaultJWTTimeout = 60 * time.Second defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login + defaultMaxRetryCount = 7 // specifies maximum number of subsequent retries defaultDomain = ".snowflakecomputing.com" ) @@ -74,6 +75,7 @@ type Config struct { ClientTimeout time.Duration // Timeout for network round trip + read out http response JWTClientTimeout time.Duration // Timeout for network round trip + read out http response used when JWT token auth is taking place ExternalBrowserTimeout time.Duration // Timeout for external browser login + MaxRetryCount int // Specifies how many times non-periodic HTTP request can be retried Application string // application name. InsecureMode bool // driver doesn't check certificate revocation status @@ -205,6 +207,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.ExternalBrowserTimeout != defaultExternalBrowserTimeout { params.Add("externalBrowserTimeout", strconv.FormatInt(int64(cfg.ExternalBrowserTimeout/time.Second), 10)) } + if cfg.MaxRetryCount != defaultMaxRetryCount { + params.Add("maxRetryCount", strconv.Itoa(cfg.MaxRetryCount)) + } if cfg.Application != clientType { params.Add("application", cfg.Application) } @@ -471,6 +476,9 @@ func fillMissingConfigParameters(cfg *Config) error { if cfg.ExternalBrowserTimeout == 0 { cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout } + if cfg.MaxRetryCount == 0 { + cfg.MaxRetryCount = defaultMaxRetryCount + } if strings.Trim(cfg.Application, " ") == "" { cfg.Application = clientType } @@ -642,6 +650,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return err } + case "maxRetryCount": + cfg.MaxRetryCount, err = strconv.Atoi(value) + if err != nil { + return err + } case "application": cfg.Application = value case "authenticator": diff --git a/dsn_test.go b/dsn_test.go index e49207bd5..f939f866c 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -336,6 +336,23 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, IncludeRetryReason: ConfigBoolTrue, + MaxRetryCount: defaultMaxRetryCount, + }, + ocspMode: ocspModeFailOpen, + }, + { + dsn: "u:p@a?database=d&maxRetryCount=20", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.snowflakecomputing.com", Port: 443, + Database: "d", Schema: "", + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + IncludeRetryReason: ConfigBoolTrue, + MaxRetryCount: 20, }, ocspMode: ocspModeFailOpen, }, @@ -1239,6 +1256,16 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&tmpDirPath=%2Ftmp&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + IncludeRetryReason: ConfigBoolFalse, + MaxRetryCount: 30, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?includeRetryReason=false&maxRetryCount=30&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, { cfg: &Config{ User: "u", diff --git a/ocsp.go b/ocsp.go index 297a999db..8fc7cef30 100644 --- a/ocsp.go +++ b/ocsp.go @@ -358,7 +358,7 @@ func checkOCSPCacheServer( ocspS *ocspStatus) { var respd map[string][]interface{} headers := make(map[string]string) - res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultTimeProvider, nil).execute() + res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultMaxRetryCount, defaultTimeProvider, nil).execute() if err != nil { logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err) return nil, &ocspStatus{ @@ -413,7 +413,7 @@ func retryOCSP( } res, err := newRetryHTTP( ctx, client, req, ocspHost, headers, - totalTimeout*time.Duration(multiplier), defaultTimeProvider, nil).doPost().setBody(reqBody).execute() + totalTimeout*time.Duration(multiplier), defaultMaxRetryCount, defaultTimeProvider, nil).doPost().setBody(reqBody).execute() if err != nil { return ocspRes, ocspResBytes, &ocspStatus{ code: ocspFailedSubmit, @@ -466,7 +466,7 @@ func fallbackRetryOCSPToGETRequest( multiplier = 3 // up to 3 times for Fail Close mode } res, err := newRetryHTTP(ctx, client, req, ocspHost, headers, - totalTimeout*time.Duration(multiplier), defaultTimeProvider, nil).execute() + totalTimeout*time.Duration(multiplier), defaultMaxRetryCount, defaultTimeProvider, nil).execute() if err != nil { return ocspRes, ocspResBytes, &ocspStatus{ code: ocspFailedSubmit, diff --git a/restful.go b/restful.go index c94e28636..777d94df4 100644 --- a/restful.go +++ b/restful.go @@ -44,7 +44,7 @@ const ( type ( funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error) funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider, *Config) (*http.Response, error) - funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error) + funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, int) (*http.Response, error) bodyCreatorType func() ([]byte, error) ) @@ -58,6 +58,7 @@ type snowflakeRestful struct { Protocol string LoginTimeout time.Duration // Login timeout RequestTimeout time.Duration // request timeout + MaxRetryCount int Client *http.Client JWTClient *http.Client @@ -165,7 +166,7 @@ func postRestful( currentTimeProvider currentTimeProvider, cfg *Config) ( *http.Response, error) { - return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider, cfg). + return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, currentTimeProvider, cfg). doPost(). setBody(body). execute() @@ -178,7 +179,7 @@ func getRestful( headers map[string]string, timeout time.Duration) ( *http.Response, error) { - return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider, nil).execute() + return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, sr.MaxRetryCount, defaultTimeProvider, nil).execute() } func postAuthRestful( @@ -187,9 +188,10 @@ func postAuthRestful( fullURL *url.URL, headers map[string]string, bodyCreator bodyCreatorType, - timeout time.Duration) ( + timeout time.Duration, + maxRetryCount int) ( *http.Response, error) { - return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider, nil). + return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, maxRetryCount, defaultTimeProvider, nil). doPost(). setBodyCreator(bodyCreator). execute() diff --git a/restful_test.go b/restful_test.go index 66b3b305b..0eb5fc445 100644 --- a/restful_test.go +++ b/restful_test.go @@ -22,7 +22,7 @@ func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str }, errors.New("failed to run post method") } -func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { +func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -43,7 +43,7 @@ func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.U }, nil } -func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { +func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadGateway, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -57,14 +57,14 @@ func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.UR }, nil } -func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { +func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { +func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusInsufficientStorage, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str }, nil } -func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { +func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ int) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, diff --git a/retry.go b/retry.go index 725a8fc46..8995b9c97 100644 --- a/retry.go +++ b/retry.go @@ -17,11 +17,6 @@ import ( "time" ) -const ( - // defaultMaxRetryCount specifies maximum number of subsequent retries - defaultMaxRetryCount = 7 -) - type waitAlgo struct { mutex *sync.Mutex // required for *rand.Rand usage random *rand.Rand @@ -248,6 +243,7 @@ type retryHTTP struct { headers map[string]string bodyCreator bodyCreatorType timeout time.Duration + maxRetryCount int currentTimeProvider currentTimeProvider cfg *Config } @@ -258,6 +254,7 @@ func newRetryHTTP(ctx context.Context, fullURL *url.URL, headers map[string]string, timeout time.Duration, + maxRetryCount int, currentTimeProvider currentTimeProvider, cfg *Config) *retryHTTP { instance := retryHTTP{} @@ -268,6 +265,7 @@ func newRetryHTTP(ctx context.Context, instance.fullURL = fullURL instance.headers = headers instance.timeout = timeout + instance.maxRetryCount = maxRetryCount instance.bodyCreator = emptyBodyCreator instance.currentTimeProvider = currentTimeProvider instance.cfg = cfg @@ -341,7 +339,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout) // if any timeout is set totalTimeout -= time.Duration(sleepTime * float64(time.Second)) - if totalTimeout <= 0 || retryCounter >= defaultMaxRetryCount { + if totalTimeout <= 0 || retryCounter > r.maxRetryCount { if err != nil { return nil, err } diff --git a/retry_test.go b/retry_test.go index 7a089ecbd..5815d66b3 100644 --- a/retry_test.go +++ b/retry_test.go @@ -223,20 +223,14 @@ func TestRetryQuerySuccess(t *testing.T) { t: t, } urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid") - if err != nil { - t.Fatal("failed to parse the test URL") - } + assertNilF(t, err, "failed to parse the test URL") _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolTrue}).doPost().setBody([]byte{0}).execute() - if err != nil { - t.Fatal("failed to run retry") - } + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolTrue}).doPost().setBody([]byte{0}).execute() + assertNilF(t, err, "failed to run retry") var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatal("failed to parse the URL") - } + assertNilF(t, err, "failed to parse the test URL") retry, err := strconv.Atoi(values.Get(retryCountKey)) if err != nil { t.Fatalf("failed to get retry counter: %v", err) @@ -272,20 +266,14 @@ func TestRetryQuerySuccessWithRetryReasonDisabled(t *testing.T) { t: t, } urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid") - if err != nil { - t.Fatal("failed to parse the test URL") - } + assertNilF(t, err, "failed to parse the test URL") _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolFalse}).doPost().setBody([]byte{0}).execute() - if err != nil { - t.Fatal("failed to run retry") - } + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolFalse}).doPost().setBody([]byte{0}).execute() + assertNilF(t, err, "failed to run retry") var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatal("failed to parse the URL") - } + assertNilF(t, err, "failed to parse the test URL") retry, err := strconv.Atoi(values.Get(retryCountKey)) if err != nil { t.Fatalf("failed to get retry counter: %v", err) @@ -318,20 +306,14 @@ func TestRetryQuerySuccessWithTimeout(t *testing.T) { t: t, } urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid") - if err != nil { - t.Fatal("failed to parse the test URL") - } + assertNilF(t, err, "failed to parse the test URL") _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), nil).doPost().setBody([]byte{0}).execute() - if err != nil { - t.Fatal("failed to run retry") - } + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, constTimeProvider(123456), nil).doPost().setBody([]byte{0}).execute() + assertNilF(t, err, "failed to run retry") var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatal("failed to parse the URL") - } + assertNilF(t, err, "failed to parse the test URL") retry, err := strconv.Atoi(values.Get(retryCountKey)) if err != nil { t.Fatalf("failed to get retry counter: %v", err) @@ -341,33 +323,52 @@ func TestRetryQuerySuccessWithTimeout(t *testing.T) { } } -func TestRetryQueryFail(t *testing.T) { +func TestRetryQueryFailWithTimeout(t *testing.T) { logger.Info("Retry N times until there is a timeout and Fail") client := &fakeHTTPClient{ statusCode: http.StatusTooManyRequests, success: false, } urlPtr, err := url.Parse("https://fakeaccountretryfail.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey) - if err != nil { - t.Fatal("failed to parse the test URL") - } + assertNilF(t, err, "failed to parse the test URL") _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 15*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() - if err == nil { - t.Fatal("should fail to run retry") + emptyRequest, urlPtr, make(map[string]string), 15*time.Second, 100, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() + assertNotNilF(t, err, "should fail to run retry") + var values url.Values + values, err = url.ParseQuery(urlPtr.RawQuery) + assertNilF(t, err, fmt.Sprintf("failed to parse the URL: %v", err)) + retry, err := strconv.Atoi(values.Get(retryCountKey)) + assertNilF(t, err, fmt.Sprintf("failed to get retry counter: %v", err)) + if retry < 2 { + t.Fatalf("not enough retries: %v", retry) } +} + +func TestRetryQueryFailWithMaxRetryCount(t *testing.T) { + maxRetryCount := 3 + logger.Info("Retry 3 times until retry reaches MaxRetryCount and Fail") + client := &fakeHTTPClient{ + statusCode: http.StatusTooManyRequests, + success: false, + } + urlPtr, err := url.Parse("https://fakeaccountretryfail.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey) + assertNilF(t, err, "failed to parse the test URL") + _, err = newRetryHTTP(context.Background(), + client, + emptyRequest, urlPtr, make(map[string]string), 15*time.Hour, maxRetryCount, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() + assertNotNilF(t, err, "should fail to run retry") var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) if err != nil { t.Fatalf("failed to parse the URL: %v", err) } - retry, err := strconv.Atoi(values.Get(retryCountKey)) + retryCount, err := strconv.Atoi(values.Get(retryCountKey)) if err != nil { t.Fatalf("failed to get retry counter: %v", err) } - if retry < 2 { - t.Fatalf("not enough retry counter: %v", retry) + if retryCount < 3 { + t.Fatalf("not enough retries: %v; expected %v", retryCount, maxRetryCount) } } @@ -394,35 +395,26 @@ func TestRetryLoginRequest(t *testing.T) { }, } urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid") - if err != nil { - t.Fatal("failed to parse the test URL") - } + assertNilF(t, err, "failed to parse the test URL") _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() - if err != nil { - t.Fatal("failed to run retry") - } + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() + assertNilF(t, err, "failed to run retry") var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatalf("failed to parse the URL: %v", err) - } + assertNilF(t, err, "failed to parse the test URL") if values.Get(retryCountKey) != "" { t.Fatalf("no retry counter should be attached: %v", retryCountKey) } logger.Info("Retry N times for timeouts and Fail") client = &fakeHTTPClient{ - cnt: 10, success: false, timeout: true, } _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 10*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() - if err == nil { - t.Fatal("should fail to run retry") - } + emptyRequest, urlPtr, make(map[string]string), 5*time.Second, 3, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() + assertNotNilF(t, err, "should fail to run retry") values, err = url.ParseQuery(urlPtr.RawQuery) if err != nil { t.Fatalf("failed to parse the URL: %v", err) @@ -440,9 +432,7 @@ func TestRetryAuthLoginRequest(t *testing.T) { timeout: true, } urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid") - if err != nil { - t.Fatal("failed to parse the test URL") - } + assertNilF(t, err, "failed to parse the test URL") execID := 0 bodyCreator := func() ([]byte, error) { execID++ @@ -450,10 +440,8 @@ func TestRetryAuthLoginRequest(t *testing.T) { } _, err = newRetryHTTP(context.Background(), client, - http.NewRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBodyCreator(bodyCreator).execute() - if err != nil { - t.Fatal("failed to run retry") - } + http.NewRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, nil).doPost().setBodyCreator(bodyCreator).execute() + assertNilF(t, err, "failed to run retry") if lastReqBody := string(client.reqBody); lastReqBody != "execID: 3" { t.Fatalf("body should be updated on each request, expected: execID: 3, last body: %v", lastReqBody) } @@ -466,20 +454,16 @@ func TestLoginRetry429(t *testing.T) { statusCode: 429, } urlPtr, err := url.Parse("https://fakeaccountretrylogin.snowflakecomputing.com:443/login-request?request_id=testid") - if err != nil { - t.Fatal("failed to parse the test URL") - } + assertNilF(t, err, "failed to parse the test URL") + _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX - if err != nil { - t.Fatal("failed to run retry") - } + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, 3, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX + assertNilF(t, err, "failed to run retry") + var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) - if err != nil { - t.Fatalf("failed to parse the URL: %v", err) - } + assertNilF(t, err, fmt.Sprintf("failed to parse the URL: %v", err)) if values.Get(retryCountKey) != "" { t.Fatalf("no retry counter should be attached: %v", retryCountKey) }