diff --git a/authokta.go b/authokta.go index 994b51c13..3e3f3c518 100644 --- a/authokta.go +++ b/authokta.go @@ -216,7 +216,7 @@ func postAuthSAML( fullURL := sr.getFullURL(authenticatorRequestPath, params) logger.Infof("fullURL: %v", fullURL) - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider, nil) if err != nil { return nil, err } @@ -274,7 +274,7 @@ func postAuthOKTA( if err != nil { return nil, err } - resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider, nil) if err != nil { return nil, err } diff --git a/chunk_downloader.go b/chunk_downloader.go index 1a298cf3e..a32fd1628 100644 --- a/chunk_downloader.go +++ b/chunk_downloader.go @@ -264,7 +264,7 @@ func getChunk( if err != nil { return nil, err } - return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.currentTimeProvider).execute() + return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.currentTimeProvider, sc.cfg).execute() } func (scd *snowflakeChunkDownloader) startArrowBatches() error { @@ -636,7 +636,7 @@ func (f *httpStreamChunkFetcher) fetch(URL string, rows chan<- []*string) error if err != nil { return err } - res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, defaultTimeProvider).execute() + res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, defaultTimeProvider, nil).execute() if err != nil { return err } diff --git a/client.go b/client.go index a5d7f747c..dff709bed 100644 --- a/client.go +++ b/client.go @@ -35,5 +35,5 @@ func (cli *httpClient) Post( timeout time.Duration, raise4xx bool, currentTimeProvider currentTimeProvider) (*http.Response, error) { - return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider) + return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider, nil) } diff --git a/dsn.go b/dsn.go index d5fa2ab73..92fd472e0 100644 --- a/dsn.go +++ b/dsn.go @@ -99,6 +99,8 @@ type Config struct { ClientStoreTemporaryCredential ConfigBool // When true the ID token is cached in the credential manager. True by default in Windows/OSX. False for Linux. DisableQueryContextCache bool // Should HTAP query context cache be disabled + + IncludeRetryReason ConfigBool // Should retried request contain retry reason } // Validate enables testing if config is correct. @@ -235,6 +237,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.DisableQueryContextCache { params.Add("disableQueryContextCache", "true") } + if cfg.IncludeRetryReason == ConfigBoolFalse { + params.Add("includeRetryReason", "false") + } params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse)) @@ -473,6 +478,10 @@ func fillMissingConfigParameters(cfg *Config) error { cfg.ValidateDefaultParameters = ConfigBoolTrue } + if cfg.IncludeRetryReason == configBoolNotSet { + cfg.IncludeRetryReason = ConfigBoolTrue + } + if strings.HasSuffix(cfg.Host, defaultDomain) && len(cfg.Host) == len(defaultDomain) { return &SnowflakeError{ Number: ErrCodeFailedToParseHost, @@ -714,6 +723,17 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } cfg.DisableQueryContextCache = b + case "includeRetryReason": + var vv bool + vv, err = strconv.ParseBool(value) + if err != nil { + return + } + if vv { + cfg.IncludeRetryReason = ConfigBoolTrue + } else { + cfg.IncludeRetryReason = ConfigBoolFalse + } default: if cfg.Params == nil { cfg.Params = make(map[string]*string) diff --git a/dsn_test.go b/dsn_test.go index 0086e2cf5..393eda8cf 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -39,6 +39,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -54,6 +55,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -67,6 +69,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -81,6 +84,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -96,6 +100,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -111,6 +116,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -125,6 +131,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -139,6 +146,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -153,6 +161,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -168,6 +177,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -183,6 +193,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -198,6 +209,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: errEmptyPassword(), @@ -213,6 +225,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: errEmptyUsername(), @@ -228,6 +241,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: errEmptyAccount(), @@ -243,6 +257,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -257,6 +272,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -271,6 +287,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -285,6 +302,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -302,6 +320,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, }, @@ -316,6 +335,7 @@ func TestParseDSN(t *testing.T) { ValidateDefaultParameters: ConfigBoolTrue, ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, }, @@ -331,6 +351,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, }, @@ -344,6 +365,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: &SnowflakeError{ @@ -364,6 +386,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeInsecure, err: nil, @@ -380,6 +403,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -431,6 +455,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -446,6 +471,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -466,6 +492,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -481,6 +508,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: &SnowflakeError{Number: ErrCodePrivateKeyParseError}, @@ -495,6 +523,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -509,6 +538,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailClosed, err: nil, @@ -523,6 +553,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeInsecure, err: nil, @@ -536,6 +567,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -549,6 +581,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -562,12 +595,13 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, }, { - dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&clientTimeout=300&jwtClientTimeout=45", + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&clientTimeout=300&jwtClientTimeout=45&includeRetryReason=false", config: &Config{ Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, @@ -576,6 +610,7 @@ func TestParseDSN(t *testing.T) { JWTClientTimeout: 45 * time.Second, ExternalBrowserTimeout: defaultExternalBrowserTimeout, DisableQueryContextCache: false, + IncludeRetryReason: ConfigBoolFalse, }, ocspMode: ocspModeFailOpen, err: nil, @@ -590,6 +625,7 @@ func TestParseDSN(t *testing.T) { JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, TmpDirPath: "/tmp", + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -604,6 +640,21 @@ func TestParseDSN(t *testing.T) { JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, DisableQueryContextCache: true, + IncludeRetryReason: ConfigBoolTrue, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, + { + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&includeRetryReason=true", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, + Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, }, ocspMode: ocspModeFailOpen, err: nil, @@ -622,6 +673,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, Authenticator: at, }, ocspMode: ocspModeFailOpen, @@ -641,6 +693,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, Authenticator: at, }, ocspMode: ocspModeFailOpen, @@ -660,6 +713,7 @@ func TestParseDSN(t *testing.T) { ClientTimeout: defaultClientTimeout, JWTClientTimeout: defaultJWTClientTimeout, ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, Authenticator: at, }, ocspMode: ocspModeFailOpen, @@ -761,6 +815,9 @@ func TestParseDSN(t *testing.T) { if test.config.DisableQueryContextCache != cfg.DisableQueryContextCache { t.Fatalf("%v: Failed to match DisableQueryContextCache. expected: %v, got: %v", i, test.config.DisableQueryContextCache, cfg.DisableQueryContextCache) } + if test.config.IncludeRetryReason != cfg.IncludeRetryReason { + t.Fatalf("%v: Failed to match IncludeRetryReason. expected: %v, got: %v", i, test.config.IncludeRetryReason, cfg.IncludeRetryReason) + } case test.err != nil: driverErrE, okE := test.err.(*SnowflakeError) driverErrG, okG := err.(*SnowflakeError) @@ -1157,9 +1214,28 @@ func TestDSN(t *testing.T) { Password: "p", Account: "a.b.c", DisableQueryContextCache: true, + IncludeRetryReason: ConfigBoolTrue, }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?disableQueryContextCache=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + IncludeRetryReason: ConfigBoolFalse, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?includeRetryReason=false&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + IncludeRetryReason: ConfigBoolTrue, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, } for _, test := range testcases { t.Run(test.dsn, func(t *testing.T) { diff --git a/heartbeat.go b/heartbeat.go index 60c4bf10e..9d44bb28c 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error { fullURL := hc.restful.getFullURL(heartBeatPath, params) timeout := hc.restful.RequestTimeout - resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider) + resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider, nil) if err != nil { return err } diff --git a/ocsp.go b/ocsp.go index c809a24c6..6feb71692 100644 --- a/ocsp.go +++ b/ocsp.go @@ -358,7 +358,7 @@ func checkOCSPCacheServer( ocspS *ocspStatus) { var respd map[string][]interface{} headers := make(map[string]string) - res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultTimeProvider).execute() + res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultTimeProvider, nil).execute() if err != nil { logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v", err) return nil, &ocspStatus{ @@ -413,7 +413,7 @@ func retryOCSP( } res, err := newRetryHTTP( ctx, client, req, ocspHost, headers, - totalTimeout*time.Duration(multiplier), defaultTimeProvider).doPost().setBody(reqBody).execute() + totalTimeout*time.Duration(multiplier), defaultTimeProvider, nil).doPost().setBody(reqBody).execute() if err != nil { return ocspRes, ocspResBytes, &ocspStatus{ code: ocspFailedSubmit, @@ -466,7 +466,7 @@ func fallbackRetryOCSPToGETRequest( multiplier = 3 // up to 3 times for Fail Close mode } res, err := newRetryHTTP(ctx, client, req, ocspHost, headers, - totalTimeout*time.Duration(multiplier), defaultTimeProvider).execute() + totalTimeout*time.Duration(multiplier), defaultTimeProvider, nil).execute() if err != nil { return ocspRes, ocspResBytes, &ocspStatus{ code: ocspFailedSubmit, diff --git a/restful.go b/restful.go index 34297c1d9..f6948c4c3 100644 --- a/restful.go +++ b/restful.go @@ -43,7 +43,7 @@ const ( type ( funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error) - funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error) + funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider, *Config) (*http.Response, error) funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error) bodyCreatorType func() ([]byte, error) ) @@ -163,9 +163,10 @@ func postRestful( body []byte, timeout time.Duration, raise4XX bool, - currentTimeProvider currentTimeProvider) ( + currentTimeProvider currentTimeProvider, + cfg *Config) ( *http.Response, error) { - return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider). + return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider, cfg). doPost(). setBody(body). doRaise4XX(raise4XX). @@ -179,7 +180,7 @@ func getRestful( headers map[string]string, timeout time.Duration) ( *http.Response, error) { - return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider).execute() + return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider, nil).execute() } func postAuthRestful( @@ -191,7 +192,7 @@ func postAuthRestful( timeout time.Duration, raise4XX bool) ( *http.Response, error) { - return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider). + return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider, nil). doPost(). setBodyCreator(bodyCreator). doRaise4XX(raise4XX). @@ -242,7 +243,7 @@ func postRestfulQueryHelper( var resp *http.Response fullURL := sr.getFullURL(queryRequestPath, params) - resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider) + resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider, cfg) if err != nil { return nil, err } @@ -334,7 +335,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati token, _, _ := sr.TokenAccessor.GetTokens() headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider, nil) if err != nil { return err } @@ -393,7 +394,7 @@ func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time return err } - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider, nil) if err != nil { return err } @@ -465,7 +466,7 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time return err } - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider, nil) if err != nil { return err } diff --git a/restful_test.go b/restful_test.go index f23cacd97..483048b55 100644 --- a/restful_test.go +++ b/restful_test.go @@ -15,7 +15,7 @@ import ( "time" ) -func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -29,14 +29,14 @@ func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[stri }, errors.New("failed to run post method") } -func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadGateway, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -50,7 +50,7 @@ func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.UR }, nil } -func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, @@ -71,7 +71,7 @@ func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.UR }, nil } -func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -90,7 +90,7 @@ func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.UR }, nil } -func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -130,7 +130,7 @@ func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map }, nil } -func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -157,7 +157,7 @@ func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID, if err != nil { return err } - resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider, nil) if err != nil { return err } @@ -462,7 +462,7 @@ func TestUnitRenewRestfulSession(t *testing.T) { accessor := getSimpleTokenAccessor() oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100) newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200) - postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { + postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) { t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken)) } diff --git a/retry.go b/retry.go index f465b6609..8fa90c8df 100644 --- a/retry.go +++ b/retry.go @@ -152,10 +152,15 @@ func (retryReasonUpdater *transientRetryReasonUpdater) replaceOrAdd(_ int) *url. return retryReasonUpdater.url } -func newRetryReasonUpdater(url *url.URL) retryReasonUpdater { +func newRetryReasonUpdater(url *url.URL, cfg *Config) retryReasonUpdater { + // not a query request if !isQueryRequest(url) { return &transientRetryReasonUpdater{url} } + // implicitly disabled retry reason + if cfg != nil && cfg.IncludeRetryReason == ConfigBoolFalse { + return &transientRetryReasonUpdater{url} + } return &retryReasonUpdate{url} } @@ -224,6 +229,7 @@ type retryHTTP struct { timeout time.Duration raise4XX bool currentTimeProvider currentTimeProvider + cfg *Config } func newRetryHTTP(ctx context.Context, @@ -232,7 +238,8 @@ func newRetryHTTP(ctx context.Context, fullURL *url.URL, headers map[string]string, timeout time.Duration, - currentTimeProvider currentTimeProvider) *retryHTTP { + currentTimeProvider currentTimeProvider, + cfg *Config) *retryHTTP { instance := retryHTTP{} instance.ctx = ctx instance.client = client @@ -244,6 +251,7 @@ func newRetryHTTP(ctx context.Context, instance.bodyCreator = emptyBodyCreator instance.raise4XX = false instance.currentTimeProvider = currentTimeProvider + instance.cfg = cfg return &instance } @@ -346,7 +354,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { } r.fullURL = retryCountUpdater.replaceOrAdd(retryCounter) if retryReasonUpdater == nil { - retryReasonUpdater = newRetryReasonUpdater(r.fullURL) + retryReasonUpdater = newRetryReasonUpdater(r.fullURL, r.cfg) } retryReason := 0 if res != nil { diff --git a/retry_test.go b/retry_test.go index e8590a21b..ea9ee18b9 100644 --- a/retry_test.go +++ b/retry_test.go @@ -227,14 +227,63 @@ func TestRetryQuerySuccess(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456)).doPost().setBody([]byte{0}).execute() + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolTrue}).doPost().setBody([]byte{0}).execute() if err != nil { t.Fatal("failed to run retry") } var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) if err != nil { - t.Fatal("failed to fail to parse the URL") + t.Fatal("failed to parse the URL") + } + retry, err := strconv.Atoi(values.Get(retryCountKey)) + if err != nil { + t.Fatalf("failed to get retry counter: %v", err) + } + if retry < 2 { + t.Fatalf("not enough retry counter: %v", retry) + } +} + +func TestRetryQuerySuccessWithRetryReasonDisabled(t *testing.T) { + logger.Info("Retry N times and Success") + client := &fakeHTTPClient{ + cnt: 3, + success: true, + statusCode: 429, + expectedQueryParams: map[int]map[string]string{ + 0: { + "retryCount": "", + "retryReason": "", + "clientStartTime": "", + }, + 1: { + "retryCount": "1", + "retryReason": "", + "clientStartTime": "123456", + }, + 2: { + "retryCount": "2", + "retryReason": "", + "clientStartTime": "123456", + }, + }, + t: t, + } + urlPtr, err := url.Parse("https://fakeaccountretrysuccess.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey + "=testid") + if err != nil { + t.Fatal("failed to parse the test URL") + } + _, err = newRetryHTTP(context.TODO(), + client, + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), &Config{IncludeRetryReason: ConfigBoolFalse}).doPost().setBody([]byte{0}).execute() + if err != nil { + t.Fatal("failed to run retry") + } + var values url.Values + values, err = url.ParseQuery(urlPtr.RawQuery) + if err != nil { + t.Fatal("failed to parse the URL") } retry, err := strconv.Atoi(values.Get(retryCountKey)) if err != nil { @@ -273,14 +322,14 @@ func TestRetryQuerySuccessWithTimeout(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456)).doPost().setBody([]byte{0}).execute() + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, constTimeProvider(123456), nil).doPost().setBody([]byte{0}).execute() if err != nil { t.Fatal("failed to run retry") } var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) if err != nil { - t.Fatal("failed to fail to parse the URL") + t.Fatal("failed to parse the URL") } retry, err := strconv.Atoi(values.Get(retryCountKey)) if err != nil { @@ -303,14 +352,14 @@ func TestRetryQueryFail(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doPost().setBody([]byte{0}).execute() + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() if err == nil { t.Fatal("should fail to run retry") } var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) if err != nil { - t.Fatalf("failed to fail to parse the URL: %v", err) + t.Fatalf("failed to parse the URL: %v", err) } retry, err := strconv.Atoi(values.Get(retryCountKey)) if err != nil { @@ -349,14 +398,14 @@ func TestRetryLoginRequest(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doPost().setBody([]byte{0}).execute() + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() if err != nil { t.Fatal("failed to run retry") } var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) if err != nil { - t.Fatalf("failed to fail to parse the URL: %v", err) + t.Fatalf("failed to parse the URL: %v", err) } if values.Get(retryCountKey) != "" { t.Fatalf("no retry counter should be attached: %v", retryCountKey) @@ -369,13 +418,13 @@ func TestRetryLoginRequest(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - emptyRequest, urlPtr, make(map[string]string), 10*time.Second, defaultTimeProvider).doPost().setBody([]byte{0}).execute() + emptyRequest, urlPtr, make(map[string]string), 10*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() if err == nil { t.Fatal("should fail to run retry") } values, err = url.ParseQuery(urlPtr.RawQuery) if err != nil { - t.Fatalf("failed to fail to parse the URL: %v", err) + t.Fatalf("failed to parse the URL: %v", err) } if values.Get(retryCountKey) != "" { t.Fatalf("no retry counter should be attached: %v", retryCountKey) @@ -400,7 +449,7 @@ func TestRetryAuthLoginRequest(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - http.NewRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doPost().setBodyCreator(bodyCreator).execute() + http.NewRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBodyCreator(bodyCreator).execute() if err != nil { t.Fatal("failed to run retry") } @@ -421,14 +470,14 @@ func TestLoginRetry429(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doRaise4XX(true).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doRaise4XX(true).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX if err != nil { t.Fatal("failed to run retry") } var values url.Values values, err = url.ParseQuery(urlPtr.RawQuery) if err != nil { - t.Fatalf("failed to fail to parse the URL: %v", err) + t.Fatalf("failed to parse the URL: %v", err) } if values.Get(retryCountKey) != "" { t.Fatalf("no retry counter should be attached: %v", retryCountKey) diff --git a/telemetry.go b/telemetry.go index 1adb085a4..911aba229 100644 --- a/telemetry.go +++ b/telemetry.go @@ -97,7 +97,7 @@ func (st *snowflakeTelemetry) sendBatch() error { } resp, err := st.sr.FuncPost(context.Background(), st.sr, st.sr.getFullURL(telemetryPath, nil), headers, body, - defaultTelemetryTimeout, true, defaultTimeProvider) + defaultTelemetryTimeout, true, defaultTimeProvider, nil) if err != nil { logger.Info("failed to upload metrics to telemetry. err: %v", err) return err diff --git a/telemetry_test.go b/telemetry_test.go index c6a7436a9..5cc1dc783 100644 --- a/telemetry_test.go +++ b/telemetry_test.go @@ -107,7 +107,7 @@ func TestEnableTelemetry(t *testing.T) { }) } -func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) { return nil, errors.New("failed to upload metrics to telemetry") }