Skip to content

Commit

Permalink
SNOW-878073 Refactor retry policy to support HTTP 503 & 429
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Oct 2, 2023
1 parent 894e78c commit 0ce7e08
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 89 deletions.
2 changes: 1 addition & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func postAuth(

fullURL := sr.getFullURL(loginRequestPath, params)
logger.Infof("full URL: %v", fullURL)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, true)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions authokta.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func postAuthSAML(
fullURL := sr.getFullURL(authenticatorRequestPath, params)

logger.Infof("fullURL: %v", fullURL)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -274,7 +274,7 @@ func postAuthOKTA(
if err != nil {
return nil, err
}
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, defaultTimeProvider)
if err != nil {
return nil, err
}
Expand Down
5 changes: 2 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
// InternalClient is implemented by HTTPClient
type InternalClient interface {
Get(context.Context, *url.URL, map[string]string, time.Duration) (*http.Response, error)
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error)
}

type httpClient struct {
Expand All @@ -33,7 +33,6 @@ func (cli *httpClient) Post(
headers map[string]string,
body []byte,
timeout time.Duration,
raise4xx bool,
currentTimeProvider currentTimeProvider) (*http.Response, error) {
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider)
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, currentTimeProvider)
}
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestInternalClient(t *testing.T) {
t.Fatalf("Expected exactly one GET request, got %v", transport.getRequests)
}

resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false, defaultTimeProvider)
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, defaultTimeProvider)
if err != nil || resp.StatusCode != 200 {
t.Fail()
}
Expand Down
2 changes: 1 addition & 1 deletion heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error {

fullURL := hc.restful.getFullURL(heartBeatPath, params)
timeout := hc.restful.RequestTimeout
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider)
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, defaultTimeProvider)
if err != nil {
return err
}
Expand Down
18 changes: 7 additions & 11 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ const (

type (
funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error)
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error)
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error)
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error)
bodyCreatorType func() ([]byte, error)
)

Expand Down Expand Up @@ -162,13 +162,11 @@ func postRestful(
headers map[string]string,
body []byte,
timeout time.Duration,
raise4XX bool,
currentTimeProvider currentTimeProvider) (
*http.Response, error) {
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider).
doPost().
setBody(body).
doRaise4XX(raise4XX).
execute()
}

Expand All @@ -188,13 +186,11 @@ func postAuthRestful(
fullURL *url.URL,
headers map[string]string,
bodyCreator bodyCreatorType,
timeout time.Duration,
raise4XX bool) (
timeout time.Duration) (
*http.Response, error) {
return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider).
doPost().
setBodyCreator(bodyCreator).
doRaise4XX(raise4XX).
execute()
}

Expand Down Expand Up @@ -242,7 +238,7 @@ func postRestfulQueryHelper(

var resp *http.Response
fullURL := sr.getFullURL(queryRequestPath, params)
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider)
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -334,7 +330,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati
token, _, _ := sr.TokenAccessor.GetTokens()
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider)
if err != nil {
return err
}
Expand Down Expand Up @@ -393,7 +389,7 @@ func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time
return err
}

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider)
if err != nil {
return err
}
Expand Down Expand Up @@ -465,7 +461,7 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time
return err
}

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider)
if err != nil {
return err
}
Expand Down
28 changes: 14 additions & 14 deletions restful_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,63 @@ import (
"time"
)

func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, errors.New("failed to run post method")
}

func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, errors.New("failed to run post method")
}

func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusForbidden,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusForbidden,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusInsufficientStorage,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -90,7 +90,7 @@ func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.UR
}, nil
}

func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str
}, nil
}

func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -130,7 +130,7 @@ func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map
}, nil
}

func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -157,7 +157,7 @@ func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID,
if err != nil {
return err
}
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider)
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, defaultTimeProvider)
if err != nil {
return err
}
Expand Down Expand Up @@ -462,7 +462,7 @@ func TestUnitRenewRestfulSession(t *testing.T) {
accessor := getSimpleTokenAccessor()
oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100)
newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200)
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider) (*http.Response, error) {
if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) {
t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken))
}
Expand Down
76 changes: 23 additions & 53 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ package gosnowflake
import (
"bytes"
"context"
"crypto/x509"
"fmt"
"golang.org/x/exp/slices"
"io"
"math/rand"
"net/http"
"net/url"
"runtime"
"strconv"
"strings"
"sync"
Expand All @@ -20,6 +19,15 @@ import (

var random *rand.Rand

var endpointsEligibleForRetry = []string{
loginRequestPath,
queryRequestPath,
tokenRequestPath,
authenticatorRequestPath,
}

var statusCodesEligibleForRetry = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable}

func init() {
random = rand.New(rand.NewSource(time.Now().UnixNano()))
}
Expand Down Expand Up @@ -222,7 +230,6 @@ type retryHTTP struct {
headers map[string]string
bodyCreator bodyCreatorType
timeout time.Duration
raise4XX bool
currentTimeProvider currentTimeProvider
}

Expand All @@ -242,16 +249,10 @@ func newRetryHTTP(ctx context.Context,
instance.headers = headers
instance.timeout = timeout
instance.bodyCreator = emptyBodyCreator
instance.raise4XX = false
instance.currentTimeProvider = currentTimeProvider
return &instance
}

func (r *retryHTTP) doRaise4XX(raise4XX bool) *retryHTTP {
r.raise4XX = raise4XX
return r
}

func (r *retryHTTP) doPost() *retryHTTP {
r.method = "POST"
return r
Expand Down Expand Up @@ -298,27 +299,19 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
req.Header.Set(k, v)
}
res, err = r.client.Do(req)
// check if it can retry.
retryable, err := r.isRetryableError(req, res, err)
if !retryable {
return res, err
}
if err != nil {
// check if it can retry.
doExit, err := r.isRetryableError(err)
if doExit {
return res, err
}
// cannot just return 4xx and 5xx status as the error can be sporadic. run often helps.
logger.WithContext(r.ctx).Warningf(
"failed http connection. no response is returned. err: %v. retrying...\n", err)
"failed http connection. err: %v. retrying...\n", err)
} else {
if res.StatusCode == http.StatusOK || r.raise4XX && res != nil && res.StatusCode >= 400 && res.StatusCode < 500 && res.StatusCode != 429 {
// exit if success
// or
// abort connection if raise4XX flag is enabled and the range of HTTP status code are 4XX.
// This is currently used for Snowflake login. The caller must generate an error object based on HTTP status.
break
}
logger.WithContext(r.ctx).Warningf(
"failed http connection. HTTP Status: %v. retrying...\n", res.StatusCode)
res.Body.Close()
}
res.Body.Close()
// uses decorrelated jitter backoff
sleepTime = defaultWaitAlgo.decorr(retryCounter, sleepTime)

Expand Down Expand Up @@ -366,36 +359,13 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
return res, r.ctx.Err()
}
}
return res, err
}

func (r *retryHTTP) isRetryableError(err error) (bool, error) {
urlError, isURLError := err.(*url.Error)
if isURLError {
// context cancel or timeout
if urlError.Err == context.DeadlineExceeded || urlError.Err == context.Canceled {
return true, urlError.Err
}
if driverError, ok := urlError.Err.(*SnowflakeError); ok {
// Certificate Revoked
if driverError.Number == ErrOCSPStatusRevoked {
return true, err
}
}
if _, ok := urlError.Err.(x509.CertificateInvalidError); ok {
// Certificate is invalid
return true, err
}
if _, ok := urlError.Err.(x509.UnknownAuthorityError); ok {
// Certificate is self-signed
return true, err
}
errString := urlError.Err.Error()
if runtime.GOOS == "darwin" && strings.HasPrefix(errString, "x509:") && strings.HasSuffix(errString, "certificate is expired") {
// Certificate is expired
return true, err
}

func (r *retryHTTP) isRetryableError(req *http.Request, res *http.Response, err error) (bool, error) {
if res == nil {
return false, err
}
return false, err
isRetryableURL := slices.Contains(endpointsEligibleForRetry, req.URL.Path)
isRetryableStatus := slices.Contains(statusCodesEligibleForRetry, res.StatusCode)
return isRetryableURL && isRetryableStatus, err
}
2 changes: 1 addition & 1 deletion retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ func TestLoginRetry429(t *testing.T) {
}
_, err = newRetryHTTP(context.TODO(),
client,
emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doRaise4XX(true).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX
emptyRequest, urlPtr, make(map[string]string), 60*time.Second, defaultTimeProvider).doPost().setBody([]byte{0}).execute() // enable doRaise4XXX
if err != nil {
t.Fatal("failed to run retry")
}
Expand Down
Loading

0 comments on commit 0ce7e08

Please sign in to comment.