Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-878073 Refactor retry policy to support HTTP 503 & 429 #919

Merged
merged 23 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0ce7e08
SNOW-878073 Refactor retry policy to support HTTP 503 & 429
sfc-gh-dheyman Oct 2, 2023
438adab
Code review suggestions applied
sfc-gh-dheyman Oct 3, 2023
d8b2fb5
Suggestions applied
sfc-gh-dheyman Oct 3, 2023
096df3b
Merge branch 'master' into SNOW-878073-retry-strategy
sfc-gh-dheyman Oct 3, 2023
4a09136
a
sfc-gh-dheyman Oct 11, 2023
da1a79e
Merge branch 'master' into SNOW-878073-retry-strategy
sfc-gh-dheyman Oct 11, 2023
e0219c6
Merge branch 'SNOW-878073-retry-strategy' of github.com:snowflakedb/g…
sfc-gh-dheyman Oct 11, 2023
0c16870
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 18, 2023
e1068d9
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 19, 2023
7590dbb
a
sfc-gh-dheyman Oct 20, 2023
dc20346
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 20, 2023
b58d00c
a
sfc-gh-dheyman Oct 24, 2023
4576da7
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 24, 2023
47eabb7
Tests added
sfc-gh-dheyman Oct 25, 2023
34a1a19
a
sfc-gh-dheyman Oct 25, 2023
4d0b39a
Fixed tests ocsp
sfc-gh-dheyman Oct 25, 2023
1770a63
a
sfc-gh-dheyman Oct 25, 2023
7389fc7
Fix DSN test
sfc-gh-dheyman Oct 25, 2023
1835cf9
Restored defer()
sfc-gh-dheyman Oct 25, 2023
c6a7558
Merge branch 'master' of github.com:snowflakedb/gosnowflake into SNOW…
sfc-gh-dheyman Oct 25, 2023
7d1bd74
Add headers to requests
sfc-gh-dheyman Oct 25, 2023
f20a035
Fix header name
sfc-gh-dheyman Oct 25, 2023
0159c0f
Applied CR suggestion
sfc-gh-dheyman Oct 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func postAuth(

fullURL := sr.getFullURL(loginRequestPath, params)
logger.Infof("full URL: %v", fullURL)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, true)
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -279,6 +279,8 @@ func getHeaders() map[string]string {
headers := make(map[string]string)
headers[httpHeaderContentType] = headerContentTypeApplicationJSON
headers[httpHeaderAccept] = headerAcceptTypeApplicationSnowflake
headers[httpClientAppID] = clientType
headers[httpClientAppVersion] = SnowflakeGoDriverVersion
headers[httpHeaderUserAgent] = userAgent
return headers
}
Expand Down
4 changes: 2 additions & 2 deletions authokta.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func postAuthSAML(
fullURL := sr.getFullURL(authenticatorRequestPath, params)

logger.Infof("fullURL: %v", fullURL)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider, nil)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -274,7 +274,7 @@ func postAuthOKTA(
if err != nil {
return nil, err
}
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider, nil)
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, defaultTimeProvider, nil)
if err != nil {
return nil, err
}
Expand Down
5 changes: 2 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
// InternalClient is implemented by HTTPClient
type InternalClient interface {
Get(context.Context, *url.URL, map[string]string, time.Duration) (*http.Response, error)
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error)
}

type httpClient struct {
Expand All @@ -33,7 +33,6 @@ func (cli *httpClient) Post(
headers map[string]string,
body []byte,
timeout time.Duration,
raise4xx bool,
currentTimeProvider currentTimeProvider) (*http.Response, error) {
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider, nil)
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, currentTimeProvider, nil)
}
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestInternalClient(t *testing.T) {
t.Fatalf("Expected exactly one GET request, got %v", transport.getRequests)
}

resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false, defaultTimeProvider)
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, defaultTimeProvider)
if err != nil || resp.StatusCode != 200 {
t.Fail()
}
Expand Down
2 changes: 2 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ const (
httpHeaderHost = "Host"
httpHeaderValueOctetStream = "application/octet-stream"
httpHeaderContentEncoding = "Content-Encoding"
httpClientAppID = "CLIENT_APP_ID"
httpClientAppVersion = "CLIENT_APP_VERSION"
)

const (
Expand Down
4 changes: 2 additions & 2 deletions driver_ocsp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,8 @@ func TestOCSPFailClosedResponder404(t *testing.T) {
if !ok {
t.Fatalf("failed to extract error URL Error: %v", err)
}
if !strings.Contains(urlErr.Err.Error(), "HTTP Status: 404") {
t.Fatalf("the root cause is not timeout: %v", urlErr.Err)
if !strings.Contains(urlErr.Err.Error(), "404 Not Found") {
t.Fatalf("the root cause is not timeout: %v", urlErr.Err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
)

const (
defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response
defaultClientTimeout = 300 * time.Second // Timeout for network round trip + read out http response
defaultJWTClientTimeout = 10 * time.Second // Timeout for network round trip + read out http response but used for JWT auth
defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout
defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout
Expand Down
8 changes: 4 additions & 4 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1119,20 +1119,20 @@ func TestDSN(t *testing.T) {
User: "u",
Password: "p",
Account: "a.b.c",
ClientTimeout: 300 * time.Second,
ClientTimeout: 400 * time.Second,
JWTClientTimeout: 60 * time.Second,
},
dsn: "u:[email protected]:443?clientTimeout=300&jwtClientTimeout=60&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
dsn: "u:[email protected]:443?clientTimeout=400&jwtClientTimeout=60&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
},
{
cfg: &Config{
User: "u",
Password: "p",
Account: "a.b.c",
ClientTimeout: 300 * time.Second,
ClientTimeout: 400 * time.Second,
JWTExpireTimeout: 30 * time.Second,
},
dsn: "u:[email protected]:443?clientTimeout=300&jwtTimeout=30&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
dsn: "u:[email protected]:443?clientTimeout=400&jwtTimeout=30&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
},
{
cfg: &Config{
Expand Down
2 changes: 1 addition & 1 deletion heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error {

fullURL := hc.restful.getFullURL(heartBeatPath, params)
timeout := hc.restful.RequestTimeout
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider, nil)
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, defaultTimeProvider, nil)
if err != nil {
return err
}
Expand Down
18 changes: 7 additions & 11 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ const (

type (
funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error)
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider, *Config) (*http.Response, error)
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error)
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider, *Config) (*http.Response, error)
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error)
bodyCreatorType func() ([]byte, error)
)

Expand Down Expand Up @@ -162,14 +162,12 @@ func postRestful(
headers map[string]string,
body []byte,
timeout time.Duration,
raise4XX bool,
currentTimeProvider currentTimeProvider,
cfg *Config) (
*http.Response, error) {
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider, cfg).
doPost().
setBody(body).
doRaise4XX(raise4XX).
execute()
}

Expand All @@ -189,13 +187,11 @@ func postAuthRestful(
fullURL *url.URL,
headers map[string]string,
bodyCreator bodyCreatorType,
timeout time.Duration,
raise4XX bool) (
timeout time.Duration) (
*http.Response, error) {
return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider, nil).
doPost().
setBodyCreator(bodyCreator).
doRaise4XX(raise4XX).
execute()
}

Expand Down Expand Up @@ -243,7 +239,7 @@ func postRestfulQueryHelper(

var resp *http.Response
fullURL := sr.getFullURL(queryRequestPath, params)
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider, cfg)
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, cfg)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -335,7 +331,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati
token, _, _ := sr.TokenAccessor.GetTokens()
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider, nil)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -394,7 +390,7 @@ func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time
return err
}

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider, nil)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -466,7 +462,7 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time
return err
}

resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider, nil)
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider, nil)
if err != nil {
return err
}
Expand Down
28 changes: 14 additions & 14 deletions restful_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,63 @@ import (
"time"
)

func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, errors.New("failed to run post method")
}

func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, errors.New("failed to run post method")
}

func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusForbidden,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusForbidden,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusInsufficientStorage,
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
}, nil
}

func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -90,7 +90,7 @@ func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.UR
}, nil
}

func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str
}, nil
}

func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -130,7 +130,7 @@ func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map
}, nil
}

func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
dd := &execResponseData{}
er := &execResponse{
Data: *dd,
Expand All @@ -157,7 +157,7 @@ func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID,
if err != nil {
return err
}
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider, nil)
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, defaultTimeProvider, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -462,7 +462,7 @@ func TestUnitRenewRestfulSession(t *testing.T) {
accessor := getSimpleTokenAccessor()
oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100)
newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200)
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) {
t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken))
}
Expand Down
Loading
Loading