Skip to content

Commit

Permalink
improve: use context.Background instead of context.TODO (#651)
Browse files Browse the repository at this point in the history
* improve: use context.Background instead of context.TODO

* refactor

---------

Co-authored-by: Piotr Fus <[email protected]>
  • Loading branch information
sivchari and sfc-gh-pfus authored Oct 17, 2023
1 parent c219d9d commit c6c2afd
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 65 deletions.
53 changes: 27 additions & 26 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,27 @@ func TestUnitPostAuth(t *testing.T) {
bodyCreator := func() ([]byte, error) {
return []byte{0x12, 0x34}, nil
}
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err != nil {
t.Fatalf("err: %v", err)
}
sr.FuncAuthPost = postAuthTestError
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err == nil {
t.Fatal("should have failed to auth for unknown reason")
}
sr.FuncAuthPost = postAuthTestAppBadGatewayError
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err == nil {
t.Fatal("should have failed to auth for unknown reason")
}
sr.FuncAuthPost = postAuthTestAppForbiddenError
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err == nil {
t.Fatal("should have failed to auth for unknown reason")
}
sr.FuncAuthPost = postAuthTestAppUnexpectedError
_, err = postAuth(context.TODO(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
_, err = postAuth(context.Background(), sr, sr.Client, &url.Values{}, make(map[string]string), bodyCreator, 0)
if err == nil {
t.Fatal("should have failed to auth for unknown reason")
}
Expand Down Expand Up @@ -131,7 +131,8 @@ func postAuthCheckOAuth(
_ *http.Client,
_ *url.Values, _ map[string]string,
bodyCreator bodyCreatorType,
_ time.Duration) (*authResponse, error) {
_ time.Duration,
) (*authResponse, error) {
var ar authRequest
jsonBody, _ := bodyCreator()
if err := json.Unmarshal(jsonBody, &ar); err != nil {
Expand Down Expand Up @@ -408,7 +409,7 @@ func TestUnitAuthenticateWithTokenAccessor(t *testing.T) {
sc.rest = sr

// FuncPostAuth is set to fail, but AuthTypeTokenAccessor should not even make a call to FuncPostAuth
resp, err := authenticate(context.TODO(), sc, []byte{}, []byte{})
resp, err := authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("should not have failed, err %v", err)
}
Expand Down Expand Up @@ -449,7 +450,7 @@ func TestUnitAuthenticate(t *testing.T) {
}
sc.rest = sr

_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -458,7 +459,7 @@ func TestUnitAuthenticate(t *testing.T) {
t.Fatalf("Snowflake error is expected. err: %v", driverErr)
}
sr.FuncPostAuth = postAuthFailWrongAccount
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -467,7 +468,7 @@ func TestUnitAuthenticate(t *testing.T) {
t.Fatalf("Snowflake error is expected. err: %v", driverErr)
}
sr.FuncPostAuth = postAuthFailUnknown
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -477,7 +478,7 @@ func TestUnitAuthenticate(t *testing.T) {
}
ta.SetTokens("bad-token", "bad-master-token", 1)
sr.FuncPostAuth = postAuthSuccessWithErrorCode
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -491,7 +492,7 @@ func TestUnitAuthenticate(t *testing.T) {
}
ta.SetTokens("bad-token", "bad-master-token", 1)
sr.FuncPostAuth = postAuthSuccessWithInvalidErrorCode
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -501,7 +502,7 @@ func TestUnitAuthenticate(t *testing.T) {
}
sr.FuncPostAuth = postAuthSuccess
var resp *authResponseMain
resp, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
resp, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to auth. err: %v", err)
}
Expand Down Expand Up @@ -533,7 +534,7 @@ func TestUnitAuthenticateSaml(t *testing.T) {
Host: "blah.okta.com",
}
sc.rest = sr
_, err = authenticate(context.TODO(), sc, []byte("HTML data in bytes from"), []byte{})
_, err = authenticate(context.Background(), sc, []byte("HTML data in bytes from"), []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
Expand All @@ -550,7 +551,7 @@ func TestUnitAuthenticateOAuth(t *testing.T) {
sc.cfg.Token = "oauthToken"
sc.cfg.Authenticator = AuthTypeOAuth
sc.rest = sr
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
Expand All @@ -566,14 +567,14 @@ func TestUnitAuthenticatePasscode(t *testing.T) {
sc.cfg.Passcode = "987654321"
sc.rest = sr

_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
sr.FuncPostAuth = postAuthCheckPasscodeInPassword
sc.rest = sr
sc.cfg.PasscodeInPassword = true
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
Expand All @@ -594,7 +595,7 @@ func TestUnitAuthenticateJWT(t *testing.T) {
sc.rest = sr

// A valid JWT token should pass
if _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}); err != nil {
if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err != nil {
t.Fatalf("failed to run. err: %v", err)
}

Expand All @@ -604,7 +605,7 @@ func TestUnitAuthenticateJWT(t *testing.T) {
t.Error(err)
}
sc.cfg.PrivateKey = invalidPrivateKey
if _, err = authenticate(context.TODO(), sc, []byte{}, []byte{}); err == nil {
if _, err = authenticate(context.Background(), sc, []byte{}, []byte{}); err == nil {
t.Fatalf("invalid token passed")
}
}
Expand All @@ -619,20 +620,20 @@ func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) {
sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
sc.rest = sr
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaToken
sc.cfg.MfaToken = "mockedMfaToken"
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckUsernamePasswordMfaFailed
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed")
}
Expand All @@ -648,7 +649,7 @@ func TestUnitAuthenticateWithConfigMFA(t *testing.T) {
sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
sc.rest = sr
sc.ctx = context.TODO()
sc.ctx = context.Background()
err = authenticateWithConfig(sc)
if err != nil {
t.Fatalf("failed to run. err: %v", err)
Expand All @@ -665,20 +666,20 @@ func TestUnitAuthenticateExternalBrowser(t *testing.T) {
sc.cfg.Authenticator = AuthTypeExternalBrowser
sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
sc.rest = sr
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckExternalBrowserToken
sc.cfg.IDToken = "mockedIDToken"
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}

sr.FuncPostAuth = postAuthCheckExternalBrowserFailed
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
_, err = authenticate(context.Background(), sc, []byte{}, []byte{})
if err == nil {
t.Fatal("should have failed")
}
Expand Down
8 changes: 4 additions & 4 deletions authexternalbrowser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ func TestUnitAuthenticateByExternalBrowser(t *testing.T) {
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout)
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFail
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout)
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode
_, _, err = authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout)
_, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -128,7 +128,7 @@ func TestAuthenticationTimeout(t *testing.T) {
FuncPostAuthSAML: postAuthExternalBrowserError,
TokenAccessor: getSimpleTokenAccessor(),
}
_, _, err := authenticateByExternalBrowser(context.TODO(), sr, authenticator, application, account, user, password, timeout)
_, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout)
if err.Error() != "authentication timed out" {
t.Fatal("should have timed out")
}
Expand Down
32 changes: 16 additions & 16 deletions authokta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ func TestUnitPostAuthSAML(t *testing.T) {
TokenAccessor: getSimpleTokenAccessor(),
}
var err error
_, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{}, 0)
_, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{}, 0)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPost = postTestAppBadGatewayError
_, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{}, 0)
_, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{}, 0)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPost = postTestSuccessButInvalidJSON
_, err = postAuthSAML(context.TODO(), sr, make(map[string]string), []byte{0x12, 0x34}, 0)
_, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{0x12, 0x34}, 0)
if err == nil {
t.Fatalf("should have failed to post")
}
Expand All @@ -86,17 +86,17 @@ func TestUnitPostAuthOKTA(t *testing.T) {
TokenAccessor: getSimpleTokenAccessor(),
}
var err error
_, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{}, "hahah", 0)
_, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{}, "hahah", 0)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPost = postTestAppBadGatewayError
_, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{}, "hahah", 0)
_, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{}, "hahah", 0)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPost = postTestSuccessButInvalidJSON
_, err = postAuthOKTA(context.TODO(), sr, make(map[string]string), []byte{0x12, 0x34}, "haha", 0)
_, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{0x12, 0x34}, "haha", 0)
if err == nil {
t.Fatal("should have failed to run post request after the renewal")
}
Expand All @@ -108,17 +108,17 @@ func TestUnitGetSSO(t *testing.T) {
TokenAccessor: getSimpleTokenAccessor(),
}
var err error
_, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
_, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncGet = getTestAppBadGatewayError
_, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
_, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncGet = getTestHTMLSuccess
_, err = getSSO(context.TODO(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
_, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
if err != nil {
t.Fatalf("failed to get HTML content. err: %v", err)
}
Expand Down Expand Up @@ -194,17 +194,17 @@ func TestUnitAuthenticateBySAML(t *testing.T) {
TokenAccessor: getSimpleTokenAccessor(),
}
var err error
_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthSAMLAuthFail
_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL
_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
Expand All @@ -217,23 +217,23 @@ func TestUnitAuthenticateBySAML(t *testing.T) {
}
sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess
sr.FuncPostAuthOKTA = postAuthOKTAError
_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncPostAuthOKTA = postAuthOKTASuccess
sr.FuncGetSSO = getSSOError
_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncGetSSO = getSSOSuccessButInvalidURL
_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
if err == nil {
t.Fatal("should have failed.")
}
sr.FuncGetSSO = getSSOSuccess
_, err = authenticateBySAML(context.TODO(), sr, authenticator, application, account, user, password)
_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
if err != nil {
t.Fatalf("failed. err: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func TestServiceName(t *testing.T) {

expectServiceName := serviceNameStub
for i := 0; i < 5; i++ {
sc.exec(context.TODO(), "", false, /* noResult */
sc.exec(context.Background(), "", false, /* noResult */
false /* isInternal */, false /* describeOnly */, nil)
if actualServiceName, ok := sc.cfg.Params[serviceName]; ok {
if *actualServiceName != expectServiceName {
Expand Down
4 changes: 2 additions & 2 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func TestCtxVal(t *testing.T) {

func TestLogEntryCtx(t *testing.T) {
var log = logger
var ctx1 = context.WithValue(context.TODO(), SFSessionIDKey, "sessID1")
var ctx2 = context.WithValue(context.TODO(), SFSessionUserKey, "admin")
var ctx1 = context.WithValue(context.Background(), SFSessionIDKey, "sessID1")
var ctx2 = context.WithValue(context.Background(), SFSessionUserKey, "admin")

fs1 := context2Fields(ctx1)
fs2 := context2Fields(ctx2)
Expand Down
2 changes: 1 addition & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type SnowflakeDriver struct{}
// Open creates a new connection.
func (d SnowflakeDriver) Open(dsn string) (driver.Conn, error) {
logger.Info("Open")
ctx := context.TODO()
ctx := context.Background()
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit c6c2afd

Please sign in to comment.