Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-878073 Refactor retry policy to support HTTP 503 & 429 #919

Merged
merged 23 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0ce7e08
SNOW-878073 Refactor retry policy to support HTTP 503 & 429
sfc-gh-dheyman Oct 2, 2023
438adab
Code review suggestions applied
sfc-gh-dheyman Oct 3, 2023
d8b2fb5
Suggestions applied
sfc-gh-dheyman Oct 3, 2023
096df3b
Merge branch 'master' into SNOW-878073-retry-strategy
sfc-gh-dheyman Oct 3, 2023
4a09136
a
sfc-gh-dheyman Oct 11, 2023
da1a79e
Merge branch 'master' into SNOW-878073-retry-strategy
sfc-gh-dheyman Oct 11, 2023
e0219c6
Merge branch 'SNOW-878073-retry-strategy' of github.com:snowflakedb/g…
sfc-gh-dheyman Oct 11, 2023
0c16870
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 18, 2023
e1068d9
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 19, 2023
7590dbb
a
sfc-gh-dheyman Oct 20, 2023
dc20346
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 20, 2023
b58d00c
a
sfc-gh-dheyman Oct 24, 2023
4576da7
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 24, 2023
47eabb7
Tests added
sfc-gh-dheyman Oct 25, 2023
34a1a19
a
sfc-gh-dheyman Oct 25, 2023
4d0b39a
Fixed tests ocsp
sfc-gh-dheyman Oct 25, 2023
1770a63
a
sfc-gh-dheyman Oct 25, 2023
7389fc7
Fix DSN test
sfc-gh-dheyman Oct 25, 2023
1835cf9
Restored defer()
sfc-gh-dheyman Oct 25, 2023
c6a7558
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 25, 2023
7d1bd74
Add headers to requests
sfc-gh-dheyman Oct 25, 2023
f20a035
Fix header name
sfc-gh-dheyman Oct 25, 2023
0159c0f
Applied CR suggestion
sfc-gh-dheyman Oct 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func postAuth(

fullURL := sr.getFullURL(loginRequestPath, params)
logger.Infof("full URL: %v", fullURL)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, true)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions authokta.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func postAuthSAML(
fullURL := sr.getFullURL(authenticatorRequestPath, params)

logger.Infof("fullURL: %v", fullURL)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -274,7 +274,7 @@ func postAuthOKTA(
if err != nil {
return nil, err
}
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, defaultTimeProvider)
if err != nil {
return nil, err
}
Expand Down
5 changes: 2 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
// InternalClient is implemented by HTTPClient
type InternalClient interface {
Get(context.Context, *url.URL, map[string]string, time.Duration) (*http.Response, error)
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error)
}

type httpClient struct {
Expand All @@ -33,7 +33,6 @@ func (cli *httpClient) Post(
headers map[string]string,
body []byte,
timeout time.Duration,
raise4xx bool,
currentTimeProvider currentTimeProvider) (*http.Response, error) {
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider)
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, currentTimeProvider)
}
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestInternalClient(t *testing.T) {
t.Fatalf("Expected exactly one GET request, got %v", transport.getRequests)
}

resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false, defaultTimeProvider)
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, defaultTimeProvider)
if err != nil || resp.StatusCode != 200 {
t.Fail()
}
Expand Down
2 changes: 1 addition & 1 deletion heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error {

fullURL := hc.restful.getFullURL(heartBeatPath, params)
timeout := hc.restful.RequestTimeout
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider)
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, defaultTimeProvider)
if err != nil {
return err
}
Expand Down
18 changes: 7 additions & 11 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ const (

type (
funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error)
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error)
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error)
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error)
bodyCreatorType func() ([]byte, error)
)

Expand Down Expand Up @@ -162,13 +162,11 @@ func postRestful(
headers map[string]string,
body []byte,
timeout time.Duration,
raise4XX bool,
currentTimeProvider currentTimeProvider) (
*http.Response, error) {
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider).
doPost().
setBody(body).
doRaise4XX(raise4XX).
execute()
}

Expand All @@ -188,13 +186,11 @@ func postAuthRestful(
fullURL *url.URL,
headers map[string]string,
bodyCreator bodyCreatorType,
timeout time.Duration,
raise4XX bool) (
timeout time.Duration) (
*http.Response, error) {
return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider).
doPost().
setBodyCreator(bodyCreator).
doRaise4XX(raise4XX).
execute()
}

Expand Down Expand Up @@ -242,7 +238,7 @@ func postRestfulQueryHelper(

var resp *http.Response
fullURL := sr.getFullURL(queryRequestPath, params)
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider)
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -334,7 +330,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati
token, _, _ := sr.TokenAccessor.GetTokens()
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider)
if err != nil {
return err
}
Expand Down Expand Up @@ -393,7 +389,7 @@ func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time
return err
}

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider)
if err != nil {
return err
}
Expand Down Expand Up @@ -465,7 +461,7 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time
return err
}

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider)
if err != nil {
return err
}
Expand Down
28 changes: 14 additions & 14 deletions restful_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,63 @@ import (
"time"
)

func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, errors.New("failed to run post method")
}

func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, errors.New("failed to run post method")
}

func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusForbidden,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusForbidden,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusInsufficientStorage,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -90,7 +90,7 @@ func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.UR
}, nil
}

func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str
}, nil
}

func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -130,7 +130,7 @@ func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map
}, nil
}

func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -157,7 +157,7 @@ func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID,
if err != nil {
return err
}
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, defaultTimeProvider)
if err != nil {
return err
}
Expand Down Expand Up @@ -462,7 +462,7 @@ func TestUnitRenewRestfulSession(t *testing.T) {
accessor := getSimpleTokenAccessor()
oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100)
newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200)
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) {
t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken))
}
Expand Down
76 changes: 23 additions & 53 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ package gosnowflake
import (
"bytes"
"context"
"crypto/x509"
"fmt"
"golang.org/x/exp/slices"
sfc-gh-dheyman marked this conversation as resolved.
Show resolved Hide resolved
"io"
"math/rand"
"net/http"
"net/url"
"runtime"
"strconv"
"strings"
"sync"
Expand All @@ -20,6 +19,15 @@ import (

var random *rand.Rand

var endpointsEligibleForRetry = []string{
loginRequestPath,
queryRequestPath,
sfc-gh-dheyman marked this conversation as resolved.
Show resolved Hide resolved
tokenRequestPath,
authenticatorRequestPath,
}

var statusCodesEligibleForRetry = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable}
sfc-gh-dheyman marked this conversation as resolved.
Show resolved Hide resolved

func init() {
random = rand.New(rand.NewSource(time.Now().UnixNano()))
}
Expand Down Expand Up @@ -222,7 +230,6 @@ type retryHTTP struct {
headers map[string]string
bodyCreator bodyCreatorType
timeout time.Duration
raise4XX bool
currentTimeProvider currentTimeProvider
}

Expand All @@ -242,16 +249,10 @@ func newRetryHTTP(ctx context.Context,
instance.headers = headers
instance.timeout = timeout
instance.bodyCreator = emptyBodyCreator
instance.raise4XX = false
instance.currentTimeProvider = currentTimeProvider
return &instance
}

func (r *retryHTTP) doRaise4XX(raise4XX bool) *retryHTTP {
r.raise4XX = raise4XX
return r
}

func (r *retryHTTP) doPost() *retryHTTP {
r.method = "POST"
return r
Expand Down Expand Up @@ -298,27 +299,19 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
req.Header.Set(k, v)
}
res, err = r.client.Do(req)
// check if it can retry.
retryable, err := r.isRetryableError(req, res, err)
if !retryable {
return res, err
}
if err != nil {
// check if it can retry.
doExit, err := r.isRetryableError(err)
if doExit {
return res, err
}
// cannot just return 4xx and 5xx status as the error can be sporadic. run often helps.
logger.WithContext(r.ctx).Warningf(
"failed http connection. no response is returned. err: %v. retrying...\n", err)
"failed http connection. err: %v. retrying...\n", err)
} else {
if res.StatusCode == http.StatusOK || r.raise4XX && res != nil && res.StatusCode >= 400 && res.StatusCode < 500 && res.StatusCode != 429 {
// exit if success
// or
// abort connection if raise4XX flag is enabled and the range of HTTP status code are 4XX.
// This is currently used for Snowflake login. The caller must generate an error object based on HTTP status.
break
}
logger.WithContext(r.ctx).Warningf(
"failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode)
res.Body.Close()
}
res.Body.Close()
sfc-gh-dheyman marked this conversation as resolved.
Show resolved Hide resolved
// uses decorrelated jitter backoff
sleepTime = defaultWaitAlgo.decorr(retryCounter, sleepTime)

Expand Down Expand Up @@ -366,36 +359,13 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
return res, r.ctx.Err()
}
}
return res, err
}

func (r *retryHTTP) isRetryableError(err error) (bool, error) {
urlError, isURLError := err.(*url.Error)
if isURLError {
// context cancel or timeout
if urlError.Err == context.DeadlineExceeded || urlError.Err == context.Canceled {
return true, urlError.Err
}
if driverError, ok := urlError.Err.(*SnowflakeError); ok {
// Certificate Revoked
if driverError.Number == ErrOCSPStatusRevoked {
return true, err
}
}
if _, ok := urlError.Err.(x509.CertificateInvalidError); ok {
// Certificate is invalid
return true, err
}
if _, ok := urlError.Err.(x509.UnknownAuthorityError); ok {
// Certificate is self-signed
return true, err
}
errString := urlError.Err.Error()
if runtime.GOOS == "darwin" && strings.HasPrefix(errString, "x509:") && strings.HasSuffix(errString, "certificate is expired") {
// Certificate is expired
return true, err
}

func (r *retryHTTP) isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) {
if res == nil {
return false, err
sfc-gh-pfus marked this conversation as resolved.
Show resolved Hide resolved
}
return false, err
isRetryableURL := slices.Contains(endpointsEligibleForRetry, req.URL.Path)
isRetryableStatus := slices.Contains(statusCodesEligibleForRetry, res.StatusCode)
return isRetryableURL && isRetryableStatus, err
}
2 changes: 1 addition & 1 deletion retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ func TestLoginRetry429(t *testing.T) {
}
_, err = newRetryHTTP(context.TODO(),
client,
emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doRaise4XX(true).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX
emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX
if err != nil {
t.Fatal("failed to run retry")
}
Expand Down
Loading
Loading