From 0ce7e089967cc4b1d912ce272dd6a9c1aa0bab22 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Mon, 2 Oct 2023 11:03:25 +0200 Subject: [PATCH 01/15] SNOW-878073 Refactor retry policy to support HTTP 503 & 429 --- auth.go | 2 +- authokta.go | 4 +-- client.go | 5 ++-- client_test.go | 2 +- heartbeat.go | 2 +- restful.go | 18 +++++------ restful_test.go | 28 ++++++++--------- retry.go | 76 ++++++++++++++--------------------------------- retry_test.go | 2 +- telemetry.go | 2 +- telemetry_test.go | 2 +- 11 files changed, 54 insertions(+), 89 deletions(-) diff --git a/auth.go b/auth.go index fa2f651c6..8aa47b4f0 100644 --- a/auth.go +++ b/auth.go @@ -226,7 +226,7 @@ func postAuth( fullURL := sr.getFullURL(loginRequestPath, params) logger.Infof("full URL: %v", fullURL) - resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, true) + resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout) if err != nil { return nil, err } diff --git a/authokta.go b/authokta.go index 994b51c13..2a490a6dc 100644 --- a/authokta.go +++ b/authokta.go @@ -216,7 +216,7 @@ func postAuthSAML( fullURL := sr.getFullURL(authenticatorRequestPath, params) logger.Infof("fullURL: %v", fullURL) - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider) if err != nil { return nil, err } @@ -274,7 +274,7 @@ func postAuthOKTA( if err != nil { return nil, err } - resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, defaultTimeProvider) if err != nil { return nil, err } diff --git a/client.go b/client.go index a5d7f747c..edfb9543f 100644 --- a/client.go +++ b/client.go @@ -12,7 +12,7 @@ import ( // InternalClient is implemented by HTTPClient type InternalClient interface { Get(context.Context, *url.URL, map[string]string, time.Duration) (*http.Response, error) - Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error) + Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error) } type httpClient struct { @@ -33,7 +33,6 @@ func (cli *httpClient) Post( headers map[string]string, body []byte, timeout time.Duration, - raise4xx bool, currentTimeProvider currentTimeProvider) (*http.Response, error) { - return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider) + return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, currentTimeProvider) } diff --git a/client_test.go b/client_test.go index 6ef66ef30..2e8beaa66 100644 --- a/client_test.go +++ b/client_test.go @@ -48,7 +48,7 @@ func TestInternalClient(t *testing.T) { t.Fatalf("Expected exactly one GET request, got %v", transport.getRequests) } - resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false, defaultTimeProvider) + resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, defaultTimeProvider) if err != nil || resp.StatusCode != 200 { t.Fail() } diff --git a/heartbeat.go b/heartbeat.go index 60c4bf10e..a4b6d6254 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error { fullURL := hc.restful.getFullURL(heartBeatPath, params) timeout := hc.restful.RequestTimeout - resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider) + resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, defaultTimeProvider) if err != nil { return err } diff --git a/restful.go b/restful.go index 34297c1d9..1ec6119af 100644 --- a/restful.go +++ b/restful.go @@ -43,8 +43,8 @@ const ( type ( funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error) - funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error) - funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error) + funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error) + funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error) bodyCreatorType func() ([]byte, error) ) @@ -162,13 +162,11 @@ func postRestful( headers map[string]string, body []byte, timeout time.Duration, - raise4XX bool, currentTimeProvider currentTimeProvider) ( *http.Response, error) { return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider). doPost(). setBody(body). - doRaise4XX(raise4XX). execute() } @@ -188,13 +186,11 @@ func postAuthRestful( fullURL *url.URL, headers map[string]string, bodyCreator bodyCreatorType, - timeout time.Duration, - raise4XX bool) ( + timeout time.Duration) ( *http.Response, error) { return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider). doPost(). setBodyCreator(bodyCreator). - doRaise4XX(raise4XX). execute() } @@ -242,7 +238,7 @@ func postRestfulQueryHelper( var resp *http.Response fullURL := sr.getFullURL(queryRequestPath, params) - resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider) + resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider) if err != nil { return nil, err } @@ -334,7 +330,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati token, _, _ := sr.TokenAccessor.GetTokens() headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider) if err != nil { return err } @@ -393,7 +389,7 @@ func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time return err } - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider) if err != nil { return err } @@ -465,7 +461,7 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time return err } - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider) if err != nil { return err } diff --git a/restful_test.go b/restful_test.go index f23cacd97..7aa32a7d4 100644 --- a/restful_test.go +++ b/restful_test.go @@ -15,63 +15,63 @@ import ( "time" ) -func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, errors.New("failed to run post method") } -func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, errors.New("failed to run post method") } -func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadGateway, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadGateway, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusInsufficientStorage, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -90,7 +90,7 @@ func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.UR }, nil } -func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str }, nil } -func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -130,7 +130,7 @@ func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map }, nil } -func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -157,7 +157,7 @@ func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID, if err != nil { return err } - resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, defaultTimeProvider) if err != nil { return err } @@ -462,7 +462,7 @@ func TestUnitRenewRestfulSession(t *testing.T) { accessor := getSimpleTokenAccessor() oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100) newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200) - postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { + postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) { t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken)) } diff --git a/retry.go b/retry.go index f465b6609..5a994ed53 100644 --- a/retry.go +++ b/retry.go @@ -5,13 +5,12 @@ package gosnowflake import ( "bytes" "context" - "crypto/x509" "fmt" + "golang.org/x/exp/slices" "io" "math/rand" "net/http" "net/url" - "runtime" "strconv" "strings" "sync" @@ -20,6 +19,15 @@ import ( var random *rand.Rand +var endpointsEligibleForRetry = []string{ + loginRequestPath, + queryRequestPath, + tokenRequestPath, + authenticatorRequestPath, +} + +var statusCodesEligibleForRetry = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} + func init() { random = rand.New(rand.NewSource(time.Now().UnixNano())) } @@ -222,7 +230,6 @@ type retryHTTP struct { headers map[string]string bodyCreator bodyCreatorType timeout time.Duration - raise4XX bool currentTimeProvider currentTimeProvider } @@ -242,16 +249,10 @@ func newRetryHTTP(ctx context.Context, instance.headers = headers instance.timeout = timeout instance.bodyCreator = emptyBodyCreator - instance.raise4XX = false instance.currentTimeProvider = currentTimeProvider return &instance } -func (r *retryHTTP) doRaise4XX(raise4XX bool) *retryHTTP { - r.raise4XX = raise4XX - return r -} - func (r *retryHTTP) doPost() *retryHTTP { r.method = "POST" return r @@ -298,27 +299,19 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { req.Header.Set(k, v) } res, err = r.client.Do(req) + // check if it can retry. + retryable, err := r.isRetryableError(req, res, err) + if !retryable { + return res, err + } if err != nil { - // check if it can retry. - doExit, err := r.isRetryableError(err) - if doExit { - return res, err - } - // cannot just return 4xx and 5xx status as the error can be sporadic. run often helps. logger.WithContext(r.ctx).Warningf( - "failed http connection. no response is returned. err: %v. retrying...\n", err) + "failed http connection. err: %v. retrying...\n", err) } else { - if res.StatusCode == http.StatusOK || r.raise4XX && res != nil && res.StatusCode >= 400 && res.StatusCode < 500 && res.StatusCode != 429 { - // exit if success - // or - // abort connection if raise4XX flag is enabled and the range of HTTP status code are 4XX. - // This is currently used for Snowflake login. The caller must generate an error object based on HTTP status. - break - } logger.WithContext(r.ctx).Warningf( "failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode) - res.Body.Close() } + res.Body.Close() // uses decorrelated jitter backoff sleepTime = defaultWaitAlgo.decorr(retryCounter, sleepTime) @@ -366,36 +359,13 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { return res, r.ctx.Err() } } - return res, err } -func (r *retryHTTP) isRetryableError(err error) (bool, error) { - urlError, isURLError := err.(*url.Error) - if isURLError { - // context cancel or timeout - if urlError.Err == context.DeadlineExceeded || urlError.Err == context.Canceled { - return true, urlError.Err - } - if driverError, ok := urlError.Err.(*SnowflakeError); ok { - // Certificate Revoked - if driverError.Number == ErrOCSPStatusRevoked { - return true, err - } - } - if _, ok := urlError.Err.(x509.CertificateInvalidError); ok { - // Certificate is invalid - return true, err - } - if _, ok := urlError.Err.(x509.UnknownAuthorityError); ok { - // Certificate is self-signed - return true, err - } - errString := urlError.Err.Error() - if runtime.GOOS == "darwin" && strings.HasPrefix(errString, "x509:") && strings.HasSuffix(errString, "certificate is expired") { - // Certificate is expired - return true, err - } - +func (r *retryHTTP) isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) { + if res == nil { + return false, err } - return false, err + isRetryableURL := slices.Contains(endpointsEligibleForRetry, req.URL.Path) + isRetryableStatus := slices.Contains(statusCodesEligibleForRetry, res.StatusCode) + return isRetryableURL && isRetryableStatus, err } diff --git a/retry_test.go b/retry_test.go index e8590a21b..9c2b5ac3f 100644 --- a/retry_test.go +++ b/retry_test.go @@ -421,7 +421,7 @@ func TestLoginRetry429(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doRaise4XX(true).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX if err != nil { t.Fatal("failed to run retry") } diff --git a/telemetry.go b/telemetry.go index 1adb085a4..eb874cc93 100644 --- a/telemetry.go +++ b/telemetry.go @@ -97,7 +97,7 @@ func (st *snowflakeTelemetry) sendBatch() error { } resp, err := st.sr.FuncPost(context.Background(), st.sr, st.sr.getFullURL(telemetryPath, nil), headers, body, - defaultTelemetryTimeout, true, defaultTimeProvider) + defaultTelemetryTimeout, defaultTimeProvider) if err != nil { logger.Info("failed to upload metrics to telemetry. err: %v", err) return err diff --git a/telemetry_test.go b/telemetry_test.go index c6a7436a9..941d1c6d9 100644 --- a/telemetry_test.go +++ b/telemetry_test.go @@ -107,7 +107,7 @@ func TestEnableTelemetry(t *testing.T) { }) } -func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return nil, errors.New("failed to upload metrics to telemetry") } From 438adab1377ead210567bf88f2f7b623fc4f80c9 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Tue, 3 Oct 2023 11:43:54 +0200 Subject: [PATCH 02/15] Code review suggestions applied --- retry.go | 6 ++---- util.go | 9 +++++++++ util_test.go | 20 ++++++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/retry.go b/retry.go index 5a994ed53..872382fc6 100644 --- a/retry.go +++ b/retry.go @@ -6,7 +6,6 @@ import ( "bytes" "context" "fmt" - "golang.org/x/exp/slices" "io" "math/rand" "net/http" @@ -21,7 +20,6 @@ var random *rand.Rand var endpointsEligibleForRetry = []string{ loginRequestPath, - queryRequestPath, tokenRequestPath, authenticatorRequestPath, } @@ -365,7 +363,7 @@ func (r *retryHTTP) isRetryableError(req *http.Request, res *http.Response, err if res == nil { return false, err } - isRetryableURL := slices.Contains(endpointsEligibleForRetry, req.URL.Path) - isRetryableStatus := slices.Contains(statusCodesEligibleForRetry, res.StatusCode) + isRetryableURL := contains(endpointsEligibleForRetry, req.URL.Path) + isRetryableStatus := contains(statusCodesEligibleForRetry, res.StatusCode) return isRetryableURL && isRetryableStatus, err } diff --git a/util.go b/util.go index ade109364..ac7b32170 100644 --- a/util.go +++ b/util.go @@ -250,3 +250,12 @@ type unixTimeProvider struct { func (utp *unixTimeProvider) currentTime() int64 { return time.Now().UnixMilli() } + +func contains[T comparable](s []T, e T) bool { + for _, v := range s { + if v == e { + return true + } + } + return false +} diff --git a/util_test.go b/util_test.go index 2d8576ff9..0f1009220 100644 --- a/util_test.go +++ b/util_test.go @@ -364,3 +364,23 @@ func TestGetFromEnvFailOnMissing(t *testing.T) { t.Error("should report error when there is missing env parameter") } } + +type tcContains[T comparable] struct { + arr []T + e T + expected bool +} + +func TestContains(t *testing.T) { + performContainsTestcase(tcContains[int]{[]int{1, 2, 3, 5}, 4, false}, t) + performContainsTestcase(tcContains[string]{[]string{"a", "b", "C", "F"}, "C", true}, t) + performContainsTestcase(tcContains[int]{[]int{1, 2, 3, 5}, 2, true}, t) + performContainsTestcase(tcContains[string]{[]string{"a", "b", "C", "F"}, "f", false}, t) +} + +func performContainsTestcase[S comparable](tc tcContains[S], t *testing.T) { + result := contains(tc.arr, tc.e) + if result != tc.expected { + t.Errorf("contains failed; arr: %v, e: %v, should be %v but was %v", tc.arr, tc.e, tc.expected, result) + } +} From d8b2fb50e91d11763d45789a051bee40cb9325f2 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Tue, 3 Oct 2023 14:49:41 +0200 Subject: [PATCH 03/15] Suggestions applied --- driver_ocsp_test.go | 4 ++-- driver_test.go | 4 ++-- retry.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/driver_ocsp_test.go b/driver_ocsp_test.go index 8ce109a3b..4b331742f 100644 --- a/driver_ocsp_test.go +++ b/driver_ocsp_test.go @@ -634,8 +634,8 @@ func TestOCSPFailClosedResponder404(t *testing.T) { if !ok { t.Fatalf("failed to extract error URL Error: %v", err) } - if !strings.Contains(urlErr.Err.Error(), "HTTP Status: 404") { - t.Fatalf("the root cause is not timeout: %v", urlErr.Err) + if !strings.Contains(urlErr.Err.Error(), "404 Not Found") { + t.Fatalf("the root cause is not timeout: %v", urlErr.Err) } } diff --git a/driver_test.go b/driver_test.go index 1d6662ebe..08f10705f 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1414,7 +1414,7 @@ func TestCancelQuery(t *testing.T) { if err == nil { dbt.Fatal("No timeout error returned") } - if err.Error() != "context deadline exceeded" { + if !strings.Contains(err.Error(), "context deadline exceeded") { dbt.Fatalf("Timeout error mismatch: expect %v, receive %v", context.DeadlineExceeded, err.Error()) } }) @@ -1506,7 +1506,7 @@ func TestLargeSetResultCancel(t *testing.T) { time.Sleep(time.Second) cancel() ret := <-c - if ret.Error() != "context canceled" { + if !strings.Contains(ret.Error(), "context canceled") { t.Fatalf("failed to cancel. err: %v", ret) } close(c) diff --git a/retry.go b/retry.go index 872382fc6..535e695a7 100644 --- a/retry.go +++ b/retry.go @@ -360,7 +360,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { } func (r *retryHTTP) isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) { - if res == nil { + if res == nil || req == nil { return false, err } isRetryableURL := contains(endpointsEligibleForRetry, req.URL.Path) From 4a0913642b73ab07937f00fdaf82cfe29858e784 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 11 Oct 2023 11:39:00 +0200 Subject: [PATCH 04/15] a --- priv_key_test.go | 13 ++++++------- retry.go | 34 +++++++++++++++++++++++----------- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/priv_key_test.go b/priv_key_test.go index adba84c91..e07430df1 100644 --- a/priv_key_test.go +++ b/priv_key_test.go @@ -16,6 +16,7 @@ import ( "encoding/pem" "fmt" "os" + "strings" "testing" ) @@ -128,14 +129,12 @@ func TestJWTTokenTimeout(t *testing.T) { } defer db.Close() ctx := context.Background() - conn, err := db.Conn(ctx) - if err != nil { - t.Fatalf(err.Error()) + _, err = db.Conn(ctx) + if err == nil || !strings.Contains(err.Error(), "Client.Timeout exceeded while awaiting headers") { + t.Fatalf("expected timeout has not occured") } - defer conn.Close() - invocations := getMocksInvocations(t) - if invocations != 3 { - t.Errorf("Unexpected number of invocations, expected 3, got %v", invocations) + if invocations != 1 { + t.Errorf("Unexpected number of invocations, expected 1, got %v", invocations) } } diff --git a/retry.go b/retry.go index 535e695a7..f250789c6 100644 --- a/retry.go +++ b/retry.go @@ -24,7 +24,14 @@ var endpointsEligibleForRetry = []string{ authenticatorRequestPath, } -var statusCodesEligibleForRetry = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} +var statusCodesEligibleForRetry = []int{ + http.StatusTooManyRequests, + http.StatusServiceUnavailable, + http.StatusBadRequest, + http.StatusForbidden, + http.StatusMethodNotAllowed, + http.StatusRequestTimeout, +} func init() { random = rand.New(rand.NewSource(time.Now().UnixNano())) @@ -189,12 +196,8 @@ type waitAlgo struct { cap time.Duration // maximum wait time } -func randSecondDuration(n time.Duration) time.Duration { - return time.Duration(random.Int63n(int64(n/time.Second))) * time.Second -} - // decorrelated jitter backoff -func (w *waitAlgo) decorr(attempt int, sleep time.Duration) time.Duration { +func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, sleep time.Duration) time.Duration { w.mutex.Lock() defer w.mutex.Unlock() t := 3*sleep - w.base @@ -207,6 +210,12 @@ func (w *waitAlgo) decorr(attempt int, sleep time.Duration) time.Duration { return w.base } +func (w *waitAlgo) getJitter(currWaitTime int) float64 { + multiplicationFactor := (random.Float64() * 2) - 1 // random float between (-1, 1) + jitterAmount := 0.5 * float64(currWaitTime) * multiplicationFactor + return jitterAmount +} + var defaultWaitAlgo = &waitAlgo{ mutex: &sync.Mutex{}, base: 5 * time.Second, @@ -298,7 +307,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { } res, err = r.client.Do(req) // check if it can retry. - retryable, err := r.isRetryableError(req, res, err) + retryable, err := isRetryableError(req, res, err) if !retryable { return res, err } @@ -311,7 +320,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { } res.Body.Close() // uses decorrelated jitter backoff - sleepTime = defaultWaitAlgo.decorr(retryCounter, sleepTime) + sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime) if totalTimeout > 0 { logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout) @@ -359,11 +368,14 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { } } -func (r *retryHTTP) isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) { +func isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) { if res == nil || req == nil { return false, err } isRetryableURL := contains(endpointsEligibleForRetry, req.URL.Path) - isRetryableStatus := contains(statusCodesEligibleForRetry, res.StatusCode) - return isRetryableURL && isRetryableStatus, err + return isRetryableURL && isRetryableStatus(res.StatusCode), err +} + +func isRetryableStatus(statusCode int) bool { + return (statusCode >= 500 && statusCode < 600) || contains(statusCodesEligibleForRetry, statusCode) } From 7590dbb03b7e10f458f56d455130f8aee826d1e0 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Fri, 20 Oct 2023 12:51:28 +0200 Subject: [PATCH 05/15] a --- retry.go | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/retry.go b/retry.go index 397b336fa..cc25dea44 100644 --- a/retry.go +++ b/retry.go @@ -16,6 +16,14 @@ import ( "time" ) +var ( + // MaxRetryCount specifies maximum number of subsequent retries + MaxRetryCount = 7 + + // MaxWaitTime specifies maximum wait time throughout subsequent retries + MaxWaitTime = 300 //seconds +) + var random *rand.Rand var endpointsEligibleForRetry = []string{ @@ -196,37 +204,30 @@ func isQueryRequest(url *url.URL) bool { } type waitAlgo struct { - mutex *sync.Mutex // required for random.Int63n - base time.Duration // base wait time - cap time.Duration // maximum wait time + mutex *sync.Mutex // required for random.Int63n + random *rand.Rand } // decorrelated jitter backoff func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, sleep time.Duration) time.Duration { w.mutex.Lock() defer w.mutex.Unlock() - t := 3*sleep - w.base - switch { - case t > 0: - return durationMin(w.cap, randSecondDuration(t)+w.base) - case t < 0: - return durationMin(w.cap, randSecondDuration(-t)+3*sleep) - } - return w.base + } func (w *waitAlgo) getJitter(currWaitTime int) float64 { - multiplicationFactor := (random.Float64() * 2) - 1 // random float between (-1, 1) - jitterAmount := 0.5 * float64(currWaitTime) * multiplicationFactor + multiplicationFactor := chooseRandomFromValues(w.random, []int{-1, 1}) // random int from [-1, 1] + jitterAmount := 0.5 * float64(currWaitTime) * float64(multiplicationFactor) return jitterAmount } -var defaultWaitAlgo = &waitAlgo{ - mutex: &sync.Mutex{}, - base: 5 * time.Second, - cap: 160 * time.Second, +func chooseRandomFromValues[T any](random *rand.Rand, arr []T) T { + valIdx := random.Intn(len(arr)) + return arr[valIdx] } +var defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random} + type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error) type clientInterface interface { @@ -289,7 +290,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { totalTimeout := r.timeout logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout) retryCounter := 0 - sleepTime := time.Duration(0) + sleepTime := time.Duration(1) clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10) var requestGUIDReplacer requestGUIDReplacer @@ -327,7 +328,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { "failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode) } res.Body.Close() - // uses decorrelated jitter backoff + // uses exponential jitter backoff sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime) if totalTimeout > 0 { From b58d00c7ba647ced8931a63c33b7313e8dced381 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Tue, 24 Oct 2023 13:03:49 +0200 Subject: [PATCH 06/15] a --- dsn.go | 2 +- restful.go | 2 +- retry.go | 31 +++++++++++++++--------------- retry_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/dsn.go b/dsn.go index 92fd472e0..c365e1506 100644 --- a/dsn.go +++ b/dsn.go @@ -19,7 +19,7 @@ import ( ) const ( - defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response + defaultClientTimeout = 300 * time.Second // Timeout for network round trip + read out http response defaultJWTClientTimeout = 10 * time.Second // Timeout for network round trip + read out http response but used for JWT auth defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout diff --git a/restful.go b/restful.go index 017728396..c94e28636 100644 --- a/restful.go +++ b/restful.go @@ -46,7 +46,7 @@ type ( funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider, *Config) (*http.Response, error) funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error) bodyCreatorType func() ([]byte, error) - +) var emptyBodyCreator = func() ([]byte, error) { return []byte{}, nil diff --git a/retry.go b/retry.go index cc25dea44..d20de2bba 100644 --- a/retry.go +++ b/retry.go @@ -16,12 +16,9 @@ import ( "time" ) -var ( +const ( // MaxRetryCount specifies maximum number of subsequent retries MaxRetryCount = 7 - - // MaxWaitTime specifies maximum wait time throughout subsequent retries - MaxWaitTime = 300 //seconds ) var random *rand.Rand @@ -208,17 +205,21 @@ type waitAlgo struct { random *rand.Rand } -// decorrelated jitter backoff -func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, sleep time.Duration) time.Duration { +// jitter backoff in seconds +func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, currWaitTime int) int { w.mutex.Lock() defer w.mutex.Unlock() - + if attempt < 2 && currWaitTime < 2 { + return 2 ^ attempt + } + jitterSleepTime := (2 ^ attempt) + w.getJitter(currWaitTime) + return jitterSleepTime } -func (w *waitAlgo) getJitter(currWaitTime int) float64 { +func (w *waitAlgo) getJitter(currWaitTime int) int { multiplicationFactor := chooseRandomFromValues(w.random, []int{-1, 1}) // random int from [-1, 1] jitterAmount := 0.5 * float64(currWaitTime) * float64(multiplicationFactor) - return jitterAmount + return int(jitterAmount) } func chooseRandomFromValues[T any](random *rand.Rand, arr []T) T { @@ -290,7 +291,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { totalTimeout := r.timeout logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout) retryCounter := 0 - sleepTime := time.Duration(1) + sleepTime := 1 // seconds clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10) var requestGUIDReplacer requestGUIDReplacer @@ -334,15 +335,15 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { if totalTimeout > 0 { logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout) // if any timeout is set - totalTimeout -= sleepTime - if totalTimeout <= 0 { + totalTimeout -= time.Duration(sleepTime) * time.Second + if totalTimeout <= 0 || retryCounter >= MaxRetryCount { if err != nil { return nil, err } if res != nil { - return nil, fmt.Errorf("timeout after %s. HTTP Status: %v. Hanging?", r.timeout, res.StatusCode) + return nil, fmt.Errorf("timeout after %s and %v retries. HTTP Status: %v. Hanging?", r.timeout, retryCounter, res.StatusCode) } - return nil, fmt.Errorf("timeout after %s. Hanging?", r.timeout) + return nil, fmt.Errorf("timeout after %s and %v retries. Hanging?", r.timeout, retryCounter) } } retryCounter++ @@ -366,7 +367,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { logger.WithContext(r.ctx).Infof("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout) logger.WithContext(r.ctx).Infof("retry count: %v, retry reason: %v", retryCounter, retryReason) - await := time.NewTimer(sleepTime) + await := time.NewTimer(time.Duration(sleepTime) * time.Second) select { case <-await.C: // retry the request diff --git a/retry_test.go b/retry_test.go index f304440f5..fc84d4dbf 100644 --- a/retry_test.go +++ b/retry_test.go @@ -483,3 +483,56 @@ func TestLoginRetry429(t *testing.T) { t.Fatalf("no retry counter should be attached: %v", retryCountKey) } } + +type retryableTc struct { + req *http.Request + res *http.Response + expected bool +} + +func TestIsRetryableError(t *testing.T) { + tcs := []retryableTc{ + { + req: nil, + res: nil, + expected: false, + }, + { + req: nil, + res: &http.Response{StatusCode: http.StatusBadRequest}, + expected: false, + }, + { + req: &http.Request{URL: &url.URL{Path: loginRequestPath}}, + res: nil, + expected: false, + }, + { + req: &http.Request{URL: &url.URL{Path: heartBeatPath}}, + res: &http.Response{StatusCode: http.StatusBadRequest}, + expected: false, + }, + { + req: &http.Request{URL: &url.URL{Path: loginRequestPath}}, + res: &http.Response{StatusCode: http.StatusNotFound}, + expected: false, + }, + { + req: &http.Request{URL: &url.URL{Path: loginRequestPath}}, + res: &http.Response{StatusCode: http.StatusTooManyRequests}, + expected: true, + }, + { + req: &http.Request{URL: &url.URL{Path: tokenRequestPath}}, + res: &http.Response{StatusCode: http.StatusServiceUnavailable}, + expected: true, + }, + } + + for _, tc := range tcs { + result, _ := isRetryableError(tc.req, tc.res, errUnknownError()) + if result != tc.expected { + t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res) + } + } +} From 47eabb780d33219780c0ae7d70463109d69fde8a Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 25 Oct 2023 13:38:56 +0200 Subject: [PATCH 07/15] Tests added --- retry.go | 49 +++++++++++++++++++++++++++++-------------------- retry_test.go | 27 ++++++++++++++++++++------- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/retry.go b/retry.go index d20de2bba..98182b6b8 100644 --- a/retry.go +++ b/retry.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "math" "math/rand" "net/http" "net/url" @@ -21,12 +22,21 @@ const ( MaxRetryCount = 7 ) +type waitAlgo struct { + mutex *sync.Mutex // required for *rand.Rand usage + random *rand.Rand +} + var random *rand.Rand +var defaultWaitAlgo *waitAlgo var endpointsEligibleForRetry = []string{ loginRequestPath, tokenRequestPath, authenticatorRequestPath, + queryRequestPath, + abortRequestPath, + sessionRequestPath, } var statusCodesEligibleForRetry = []int{ @@ -40,6 +50,7 @@ var statusCodesEligibleForRetry = []int{ func init() { random = rand.New(rand.NewSource(time.Now().UnixNano())) + defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random} } const ( @@ -200,26 +211,23 @@ func isQueryRequest(url *url.URL) bool { return strings.HasPrefix(url.Path, queryRequestPath) } -type waitAlgo struct { - mutex *sync.Mutex // required for random.Int63n - random *rand.Rand -} - // jitter backoff in seconds -func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, currWaitTime int) int { +func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, currWaitTime float64) float64 { w.mutex.Lock() defer w.mutex.Unlock() - if attempt < 2 && currWaitTime < 2 { - return 2 ^ attempt + var jitterPercentage = 0.5 + if attempt < 2 { + jitterPercentage = 0.25 // to ensure there will be sleep time increase between attempts } - jitterSleepTime := (2 ^ attempt) + w.getJitter(currWaitTime) - return jitterSleepTime + jitterAmount := w.getJitter(currWaitTime, jitterPercentage) + jitteredSleepTime := math.Pow(2, float64(attempt)) + jitterAmount + return jitteredSleepTime } -func (w *waitAlgo) getJitter(currWaitTime int) int { +func (w *waitAlgo) getJitter(currWaitTime float64, jitterPercentage float64) float64 { multiplicationFactor := chooseRandomFromValues(w.random, []int{-1, 1}) // random int from [-1, 1] - jitterAmount := 0.5 * float64(currWaitTime) * float64(multiplicationFactor) - return int(jitterAmount) + jitterAmount := jitterPercentage * currWaitTime * float64(multiplicationFactor) + return jitterAmount } func chooseRandomFromValues[T any](random *rand.Rand, arr []T) T { @@ -227,8 +235,6 @@ func chooseRandomFromValues[T any](random *rand.Rand, arr []T) T { return arr[valIdx] } -var defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random} - type requestFunc func(method, urlStr string, body io.Reader) (*http.Request, error) type clientInterface interface { @@ -291,7 +297,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { totalTimeout := r.timeout logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout) retryCounter := 0 - sleepTime := 1 // seconds + sleepTime := 1.0 // seconds clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10) var requestGUIDReplacer requestGUIDReplacer @@ -327,15 +333,16 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { } else { logger.WithContext(r.ctx).Warningf( "failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode) + res.Body.Close() } - res.Body.Close() // uses exponential jitter backoff + retryCounter++ sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime) if totalTimeout > 0 { logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout) // if any timeout is set - totalTimeout -= time.Duration(sleepTime) * time.Second + totalTimeout -= time.Duration(sleepTime * float64(time.Second)) if totalTimeout <= 0 || retryCounter >= MaxRetryCount { if err != nil { return nil, err @@ -346,7 +353,6 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { return nil, fmt.Errorf("timeout after %s and %v retries. Hanging?", r.timeout, retryCounter) } } - retryCounter++ if requestGUIDReplacer == nil { requestGUIDReplacer = newRequestGUIDReplace(r.fullURL) } @@ -367,7 +373,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { logger.WithContext(r.ctx).Infof("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout) logger.WithContext(r.ctx).Infof("retry count: %v, retry reason: %v", retryCounter, retryReason) - await := time.NewTimer(time.Duration(sleepTime) * time.Second) + await := time.NewTimer(time.Duration(sleepTime * float64(time.Second))) select { case <-await.C: // retry the request @@ -379,6 +385,9 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { } func isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) { + if err != nil && res == nil { // Failed http connection. Most probably client timeout. + return true, err + } if res == nil || req == nil { return false, err } diff --git a/retry_test.go b/retry_test.go index fc84d4dbf..d434aa8be 100644 --- a/retry_test.go +++ b/retry_test.go @@ -341,10 +341,10 @@ func TestRetryQuerySuccessWithTimeout(t *testing.T) { } func TestRetryQueryFail(t *testing.T) { - logger.Info("Retry N times and Fail") + logger.Info("Retry N times until there is a timeout and Fail") client := &fakeHTTPClient{ - cnt: 4, - success: false, + statusCode: http.StatusTooManyRequests, + success: false, } urlPtr, err := url.Parse("https://fakeaccountretryfail.snowflakecomputing.com:443/queries/v1/query-request?" + requestIDKey) if err != nil { @@ -352,7 +352,7 @@ func TestRetryQueryFail(t *testing.T) { } _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() + emptyRequest, urlPtr, make(map[string]string), 30*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() if err == nil { t.Fatal("should fail to run retry") } @@ -487,29 +487,34 @@ func TestLoginRetry429(t *testing.T) { type retryableTc struct { req *http.Request res *http.Response + err error expected bool } -func TestIsRetryableError(t *testing.T) { +func TestIsRetryable(t *testing.T) { tcs := []retryableTc{ { req: nil, res: nil, + err: nil, expected: false, }, { req: nil, res: &http.Response{StatusCode: http.StatusBadRequest}, + err: nil, expected: false, }, { req: &http.Request{URL: &url.URL{Path: loginRequestPath}}, res: nil, + err: nil, expected: false, }, { req: &http.Request{URL: &url.URL{Path: heartBeatPath}}, res: &http.Response{StatusCode: http.StatusBadRequest}, + err: nil, expected: false, }, { @@ -517,20 +522,28 @@ func TestIsRetryableError(t *testing.T) { res: &http.Response{StatusCode: http.StatusNotFound}, expected: false, }, + { + req: &http.Request{URL: &url.URL{Path: loginRequestPath}}, + res: nil, + err: errUnknownError(), + expected: true, + }, { req: &http.Request{URL: &url.URL{Path: loginRequestPath}}, res: &http.Response{StatusCode: http.StatusTooManyRequests}, + err: nil, expected: true, }, { - req: &http.Request{URL: &url.URL{Path: tokenRequestPath}}, + req: &http.Request{URL: &url.URL{Path: queryRequestPath}}, res: &http.Response{StatusCode: http.StatusServiceUnavailable}, + err: nil, expected: true, }, } for _, tc := range tcs { - result, _ := isRetryableError(tc.req, tc.res, errUnknownError()) + result, _ := isRetryableError(tc.req, tc.res, tc.err) if result != tc.expected { t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res) } From 34a1a197264639346b1e37c0888af8d5c7e555fa Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 25 Oct 2023 14:18:34 +0200 Subject: [PATCH 08/15] a --- retry_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/retry_test.go b/retry_test.go index d434aa8be..3da937238 100644 --- a/retry_test.go +++ b/retry_test.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "log" "net/http" "net/url" "strconv" @@ -549,3 +550,19 @@ func TestIsRetryable(t *testing.T) { } } } + +func TestExponentialJitterBackoff(t *testing.T) { + retryTimes := make([]float64, 10) + inputTime := 1.0 + for i := 0; i < 10; i++ { + resultTime := defaultWaitAlgo.calculateWaitBeforeRetry(i+1, inputTime) + retryTimes[i] = resultTime + inputTime = resultTime + } + + for i := 0; i < 9; i++ { + if retryTimes[i] >= retryTimes[i+1] { + log.Fatalf("expected consequent values to be greater than previous ones; array: %v", retryTimes) + } + } +} From 4d0b39aa8555a68806f3a8124351faec9dc2fc73 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 25 Oct 2023 14:26:04 +0200 Subject: [PATCH 09/15] Fixed tests ocsp --- retry.go | 1 - 1 file changed, 1 deletion(-) diff --git a/retry.go b/retry.go index 98182b6b8..04a8ea8c3 100644 --- a/retry.go +++ b/retry.go @@ -43,7 +43,6 @@ var statusCodesEligibleForRetry = []int{ http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusBadRequest, - http.StatusForbidden, http.StatusMethodNotAllowed, http.StatusRequestTimeout, } From 1770a630c721753482f4d2a7b148ba016c2b4339 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 25 Oct 2023 14:50:21 +0200 Subject: [PATCH 10/15] a --- driver_test.go | 4 ++-- priv_key_test.go | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/driver_test.go b/driver_test.go index 08f10705f..1d6662ebe 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1414,7 +1414,7 @@ func TestCancelQuery(t *testing.T) { if err == nil { dbt.Fatal("No timeout error returned") } - if !strings.Contains(err.Error(), "context deadline exceeded") { + if err.Error() != "context deadline exceeded" { dbt.Fatalf("Timeout error mismatch: expect %v, receive %v", context.DeadlineExceeded, err.Error()) } }) @@ -1506,7 +1506,7 @@ func TestLargeSetResultCancel(t *testing.T) { time.Sleep(time.Second) cancel() ret := <-c - if !strings.Contains(ret.Error(), "context canceled") { + if ret.Error() != "context canceled" { t.Fatalf("failed to cancel. err: %v", ret) } close(c) diff --git a/priv_key_test.go b/priv_key_test.go index e07430df1..20b6aade6 100644 --- a/priv_key_test.go +++ b/priv_key_test.go @@ -16,7 +16,6 @@ import ( "encoding/pem" "fmt" "os" - "strings" "testing" ) @@ -130,11 +129,11 @@ func TestJWTTokenTimeout(t *testing.T) { defer db.Close() ctx := context.Background() _, err = db.Conn(ctx) - if err == nil || !strings.Contains(err.Error(), "Client.Timeout exceeded while awaiting headers") { - t.Fatalf("expected timeout has not occured") + if err != nil { + t.Fatalf(err.Error()) } invocations := getMocksInvocations(t) - if invocations != 1 { - t.Errorf("Unexpected number of invocations, expected 1, got %v", invocations) + if invocations != 3 { + t.Errorf("Unexpected number of invocations, expected 3, got %v", invocations) } } From 7389fc7badb049492b263e5acede39f01345fb92 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 25 Oct 2023 14:54:42 +0200 Subject: [PATCH 11/15] Fix DSN test --- dsn_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dsn_test.go b/dsn_test.go index 393eda8cf..e3e30d695 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -1119,20 +1119,20 @@ func TestDSN(t *testing.T) { User: "u", Password: "p", Account: "a.b.c", - ClientTimeout: 300 * time.Second, + ClientTimeout: 400 * time.Second, JWTClientTimeout: 60 * time.Second, }, - dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&jwtClientTimeout=60&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=400&jwtClientTimeout=60&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ User: "u", Password: "p", Account: "a.b.c", - ClientTimeout: 300 * time.Second, + ClientTimeout: 400 * time.Second, JWTExpireTimeout: 30 * time.Second, }, - dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=300&jwtTimeout=30&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientTimeout=400&jwtTimeout=30&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, { cfg: &Config{ From 1835cf9c1856460643fd475c2c25f7bd372edb5b Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 25 Oct 2023 14:57:13 +0200 Subject: [PATCH 12/15] Restored defer() --- priv_key_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/priv_key_test.go b/priv_key_test.go index 20b6aade6..adba84c91 100644 --- a/priv_key_test.go +++ b/priv_key_test.go @@ -128,10 +128,12 @@ func TestJWTTokenTimeout(t *testing.T) { } defer db.Close() ctx := context.Background() - _, err = db.Conn(ctx) + conn, err := db.Conn(ctx) if err != nil { t.Fatalf(err.Error()) } + defer conn.Close() + invocations := getMocksInvocations(t) if invocations != 3 { t.Errorf("Unexpected number of invocations, expected 3, got %v", invocations) From 7d1bd74b95af8c353122cb30b3b26bdcdc17be5e Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 25 Oct 2023 15:47:17 +0200 Subject: [PATCH 13/15] Add headers to requests --- auth.go | 2 ++ connection.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/auth.go b/auth.go index f25af41c8..cb258a442 100644 --- a/auth.go +++ b/auth.go @@ -279,6 +279,8 @@ func getHeaders() map[string]string { headers := make(map[string]string) headers[httpHeaderContentType] = headerContentTypeApplicationJSON headers[httpHeaderAccept] = headerAcceptTypeApplicationSnowflake + headers[httpClientAppId] = clientType + headers[httpClientAppVersion] = SnowflakeGoDriverVersion headers[httpHeaderUserAgent] = userAgent return headers } diff --git a/connection.go b/connection.go index a3c04a0c6..99def3ea0 100644 --- a/connection.go +++ b/connection.go @@ -34,6 +34,8 @@ const ( httpHeaderHost = "Host" httpHeaderValueOctetStream = "application/octet-stream" httpHeaderContentEncoding = "Content-Encoding" + httpClientAppId = "CLIENT_APP_ID" + httpClientAppVersion = "CLIENT_APP_VERSION" ) const ( From f20a03520f53d289e22e65d0ffffdc1f2fd8593d Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Wed, 25 Oct 2023 16:18:00 +0200 Subject: [PATCH 14/15] Fix header name --- auth.go | 2 +- connection.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/auth.go b/auth.go index cb258a442..c3894e43b 100644 --- a/auth.go +++ b/auth.go @@ -279,7 +279,7 @@ func getHeaders() map[string]string { headers := make(map[string]string) headers[httpHeaderContentType] = headerContentTypeApplicationJSON headers[httpHeaderAccept] = headerAcceptTypeApplicationSnowflake - headers[httpClientAppId] = clientType + headers[httpClientAppID] = clientType headers[httpClientAppVersion] = SnowflakeGoDriverVersion headers[httpHeaderUserAgent] = userAgent return headers diff --git a/connection.go b/connection.go index 99def3ea0..6fcc2c437 100644 --- a/connection.go +++ b/connection.go @@ -34,7 +34,7 @@ const ( httpHeaderHost = "Host" httpHeaderValueOctetStream = "application/octet-stream" httpHeaderContentEncoding = "Content-Encoding" - httpClientAppId = "CLIENT_APP_ID" + httpClientAppID = "CLIENT_APP_ID" httpClientAppVersion = "CLIENT_APP_VERSION" ) From 0159c0f527add63976748ea0d6e6784365da7324 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Thu, 26 Oct 2023 09:03:36 +0200 Subject: [PATCH 15/15] Applied CR suggestion --- retry.go | 11 +++++------ retry_test.go | 16 +++++++--------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/retry.go b/retry.go index 04a8ea8c3..725a8fc46 100644 --- a/retry.go +++ b/retry.go @@ -18,8 +18,8 @@ import ( ) const ( - // MaxRetryCount specifies maximum number of subsequent retries - MaxRetryCount = 7 + // defaultMaxRetryCount specifies maximum number of subsequent retries + defaultMaxRetryCount = 7 ) type waitAlgo struct { @@ -39,9 +39,8 @@ var endpointsEligibleForRetry = []string{ sessionRequestPath, } -var statusCodesEligibleForRetry = []int{ +var clientErrorsStatusCodesEligibleForRetry = []int{ http.StatusTooManyRequests, - http.StatusServiceUnavailable, http.StatusBadRequest, http.StatusMethodNotAllowed, http.StatusRequestTimeout, @@ -342,7 +341,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout) // if any timeout is set totalTimeout -= time.Duration(sleepTime * float64(time.Second)) - if totalTimeout <= 0 || retryCounter >= MaxRetryCount { + if totalTimeout <= 0 || retryCounter >= defaultMaxRetryCount { if err != nil { return nil, err } @@ -395,5 +394,5 @@ func isRetryableError(req *http.Request, res *http.Response, err error) (bool, e } func isRetryableStatus(statusCode int) bool { - return (statusCode >= 500 && statusCode < 600) || contains(statusCodesEligibleForRetry, statusCode) + return (statusCode >= 500 && statusCode < 600) || contains(clientErrorsStatusCodesEligibleForRetry, statusCode) } diff --git a/retry_test.go b/retry_test.go index 3da937238..7a089ecbd 100644 --- a/retry_test.go +++ b/retry_test.go @@ -353,7 +353,7 @@ func TestRetryQueryFail(t *testing.T) { } _, err = newRetryHTTP(context.Background(), client, - emptyRequest, urlPtr, make(map[string]string), 30*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() + emptyRequest, urlPtr, make(map[string]string), 15*time.Second, defaultTimeProvider, nil).doPost().setBody([]byte{0}).execute() if err == nil { t.Fatal("should fail to run retry") } @@ -485,15 +485,13 @@ func TestLoginRetry429(t *testing.T) { } } -type retryableTc struct { - req *http.Request - res *http.Response - err error - expected bool -} - func TestIsRetryable(t *testing.T) { - tcs := []retryableTc{ + tcs := []struct { + req *http.Request + res *http.Response + err error + expected bool + }{ { req: nil, res: nil,