diff --git a/auth.go b/auth.go index fa2f651c6..8aa47b4f0 100644 --- a/auth.go +++ b/auth.go @@ -226,7 +226,7 @@ func postAuth( fullURL := sr.getFullURL(loginRequestPath, params) logger.Infof("full URL: %v", fullURL) - resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, true) + resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout) if err != nil { return nil, err } diff --git a/authokta.go b/authokta.go index 994b51c13..2a490a6dc 100644 --- a/authokta.go +++ b/authokta.go @@ -216,7 +216,7 @@ func postAuthSAML( fullURL := sr.getFullURL(authenticatorRequestPath, params) logger.Infof("fullURL: %v", fullURL) - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider) if err != nil { return nil, err } @@ -274,7 +274,7 @@ func postAuthOKTA( if err != nil { return nil, err } - resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, defaultTimeProvider) if err != nil { return nil, err } diff --git a/client.go b/client.go index a5d7f747c..edfb9543f 100644 --- a/client.go +++ b/client.go @@ -12,7 +12,7 @@ import ( // InternalClient is implemented by HTTPClient type InternalClient interface { Get(context.Context, *url.URL, map[string]string, time.Duration) (*http.Response, error) - Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error) + Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error) } type httpClient struct { @@ -33,7 +33,6 @@ func (cli *httpClient) Post( headers map[string]string, body []byte, timeout time.Duration, - raise4xx bool, currentTimeProvider currentTimeProvider) (*http.Response, error) { - return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider) + return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, currentTimeProvider) } diff --git a/client_test.go b/client_test.go index 6ef66ef30..2e8beaa66 100644 --- a/client_test.go +++ b/client_test.go @@ -48,7 +48,7 @@ func TestInternalClient(t *testing.T) { t.Fatalf("Expected exactly one GET request, got %v", transport.getRequests) } - resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false, defaultTimeProvider) + resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, defaultTimeProvider) if err != nil || resp.StatusCode != 200 { t.Fail() } diff --git a/heartbeat.go b/heartbeat.go index 60c4bf10e..a4b6d6254 100644 --- a/heartbeat.go +++ b/heartbeat.go @@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error { fullURL := hc.restful.getFullURL(heartBeatPath, params) timeout := hc.restful.RequestTimeout - resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider) + resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, defaultTimeProvider) if err != nil { return err } diff --git a/restful.go b/restful.go index 34297c1d9..1ec6119af 100644 --- a/restful.go +++ b/restful.go @@ -43,8 +43,8 @@ const ( type ( funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error) - funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error) - funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error) + funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error) + funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error) bodyCreatorType func() ([]byte, error) ) @@ -162,13 +162,11 @@ func postRestful( headers map[string]string, body []byte, timeout time.Duration, - raise4XX bool, currentTimeProvider currentTimeProvider) ( *http.Response, error) { return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider). doPost(). setBody(body). - doRaise4XX(raise4XX). execute() } @@ -188,13 +186,11 @@ func postAuthRestful( fullURL *url.URL, headers map[string]string, bodyCreator bodyCreatorType, - timeout time.Duration, - raise4XX bool) ( + timeout time.Duration) ( *http.Response, error) { return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider). doPost(). setBodyCreator(bodyCreator). - doRaise4XX(raise4XX). execute() } @@ -242,7 +238,7 @@ func postRestfulQueryHelper( var resp *http.Response fullURL := sr.getFullURL(queryRequestPath, params) - resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider) + resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider) if err != nil { return nil, err } @@ -334,7 +330,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati token, _, _ := sr.TokenAccessor.GetTokens() headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token) - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider) if err != nil { return err } @@ -393,7 +389,7 @@ func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time return err } - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider) if err != nil { return err } @@ -465,7 +461,7 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time return err } - resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider) if err != nil { return err } diff --git a/restful_test.go b/restful_test.go index f23cacd97..7aa32a7d4 100644 --- a/restful_test.go +++ b/restful_test.go @@ -15,63 +15,63 @@ import ( "time" ) -func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, errors.New("failed to run post method") } -func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, errors.New("failed to run post method") } -func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadGateway, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadGateway, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusInsufficientStorage, Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, }, nil } -func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -90,7 +90,7 @@ func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.UR }, nil } -func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str }, nil } -func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) { +func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -130,7 +130,7 @@ func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map }, nil } -func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { dd := &execResponseData{} er := &execResponse{ Data: *dd, @@ -157,7 +157,7 @@ func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID, if err != nil { return err } - resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider) + resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, defaultTimeProvider) if err != nil { return err } @@ -462,7 +462,7 @@ func TestUnitRenewRestfulSession(t *testing.T) { accessor := getSimpleTokenAccessor() oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100) newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200) - postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { + postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) { t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken)) } diff --git a/retry.go b/retry.go index f465b6609..5a994ed53 100644 --- a/retry.go +++ b/retry.go @@ -5,13 +5,12 @@ package gosnowflake import ( "bytes" "context" - "crypto/x509" "fmt" + "golang.org/x/exp/slices" "io" "math/rand" "net/http" "net/url" - "runtime" "strconv" "strings" "sync" @@ -20,6 +19,15 @@ import ( var random *rand.Rand +var endpointsEligibleForRetry = []string{ + loginRequestPath, + queryRequestPath, + tokenRequestPath, + authenticatorRequestPath, +} + +var statusCodesEligibleForRetry = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} + func init() { random = rand.New(rand.NewSource(time.Now().UnixNano())) } @@ -222,7 +230,6 @@ type retryHTTP struct { headers map[string]string bodyCreator bodyCreatorType timeout time.Duration - raise4XX bool currentTimeProvider currentTimeProvider } @@ -242,16 +249,10 @@ func newRetryHTTP(ctx context.Context, instance.headers = headers instance.timeout = timeout instance.bodyCreator = emptyBodyCreator - instance.raise4XX = false instance.currentTimeProvider = currentTimeProvider return &instance } -func (r *retryHTTP) doRaise4XX(raise4XX bool) *retryHTTP { - r.raise4XX = raise4XX - return r -} - func (r *retryHTTP) doPost() *retryHTTP { r.method = "POST" return r @@ -298,27 +299,19 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { req.Header.Set(k, v) } res, err = r.client.Do(req) + // check if it can retry. + retryable, err := r.isRetryableError(req, res, err) + if !retryable { + return res, err + } if err != nil { - // check if it can retry. - doExit, err := r.isRetryableError(err) - if doExit { - return res, err - } - // cannot just return 4xx and 5xx status as the error can be sporadic. run often helps. logger.WithContext(r.ctx).Warningf( - "failed http connection. no response is returned. err: %v. retrying...\n", err) + "failed http connection. err: %v. retrying...\n", err) } else { - if res.StatusCode == http.StatusOK || r.raise4XX && res != nil && res.StatusCode >= 400 && res.StatusCode < 500 && res.StatusCode != 429 { - // exit if success - // or - // abort connection if raise4XX flag is enabled and the range of HTTP status code are 4XX. - // This is currently used for Snowflake login. The caller must generate an error object based on HTTP status. - break - } logger.WithContext(r.ctx).Warningf( "failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode) - res.Body.Close() } + res.Body.Close() // uses decorrelated jitter backoff sleepTime = defaultWaitAlgo.decorr(retryCounter, sleepTime) @@ -366,36 +359,13 @@ func (r *retryHTTP) execute() (res *http.Response, err error) { return res, r.ctx.Err() } } - return res, err } -func (r *retryHTTP) isRetryableError(err error) (bool, error) { - urlError, isURLError := err.(*url.Error) - if isURLError { - // context cancel or timeout - if urlError.Err == context.DeadlineExceeded || urlError.Err == context.Canceled { - return true, urlError.Err - } - if driverError, ok := urlError.Err.(*SnowflakeError); ok { - // Certificate Revoked - if driverError.Number == ErrOCSPStatusRevoked { - return true, err - } - } - if _, ok := urlError.Err.(x509.CertificateInvalidError); ok { - // Certificate is invalid - return true, err - } - if _, ok := urlError.Err.(x509.UnknownAuthorityError); ok { - // Certificate is self-signed - return true, err - } - errString := urlError.Err.Error() - if runtime.GOOS == "darwin" && strings.HasPrefix(errString, "x509:") && strings.HasSuffix(errString, "certificate is expired") { - // Certificate is expired - return true, err - } - +func (r *retryHTTP) isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) { + if res == nil { + return false, err } - return false, err + isRetryableURL := slices.Contains(endpointsEligibleForRetry, req.URL.Path) + isRetryableStatus := slices.Contains(statusCodesEligibleForRetry, res.StatusCode) + return isRetryableURL && isRetryableStatus, err } diff --git a/retry_test.go b/retry_test.go index e8590a21b..9c2b5ac3f 100644 --- a/retry_test.go +++ b/retry_test.go @@ -421,7 +421,7 @@ func TestLoginRetry429(t *testing.T) { } _, err = newRetryHTTP(context.TODO(), client, - emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doRaise4XX(true).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX + emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX if err != nil { t.Fatal("failed to run retry") } diff --git a/telemetry.go b/telemetry.go index 1adb085a4..eb874cc93 100644 --- a/telemetry.go +++ b/telemetry.go @@ -97,7 +97,7 @@ func (st *snowflakeTelemetry) sendBatch() error { } resp, err := st.sr.FuncPost(context.Background(), st.sr, st.sr.getFullURL(telemetryPath, nil), headers, body, - defaultTelemetryTimeout, true, defaultTimeProvider) + defaultTelemetryTimeout, defaultTimeProvider) if err != nil { logger.Info("failed to upload metrics to telemetry. err: %v", err) return err diff --git a/telemetry_test.go b/telemetry_test.go index c6a7436a9..941d1c6d9 100644 --- a/telemetry_test.go +++ b/telemetry_test.go @@ -107,7 +107,7 @@ func TestEnableTelemetry(t *testing.T) { }) } -func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) { +func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) { return nil, errors.New("failed to upload metrics to telemetry") }