diff --git a/auth.go b/auth.go index 4126355e6..fa2f651c6 100644 --- a/auth.go +++ b/auth.go @@ -70,6 +70,9 @@ func determineAuthenticatorType(cfg *Config, value string) error { } else if upperCaseValue == AuthTypeUsernamePasswordMFA.String() { cfg.Authenticator = AuthTypeUsernamePasswordMFA return nil + } else if upperCaseValue == AuthTypeTokenAccessor.String() { + cfg.Authenticator = AuthTypeTokenAccessor + return nil } else { // possibly Okta case oktaURLString, err := url.QueryUnescape(lowerCaseValue) diff --git a/dsn_test.go b/dsn_test.go index 437297f0a..787f165e8 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -1100,6 +1100,15 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&token=t&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeTokenAccessor, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=tokenaccessor&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, } for _, test := range testcases { dsn, err := DSN(test.cfg) diff --git a/file_transfer_agent.go b/file_transfer_agent.go index 12fd7c44d..36a69d808 100644 --- a/file_transfer_agent.go +++ b/file_transfer_agent.go @@ -1156,21 +1156,21 @@ type snowflakeProgressPercentage struct { func (spp *snowflakeProgressPercentage) call(bytesAmount int64) { if spp.outputStream != nil { spp.seenSoFar += bytesAmount - percentage := percent(spp.seenSoFar, spp.fileSize) + percentage := spp.percent(spp.seenSoFar, spp.fileSize) if !spp.done { - spp.done = updateProgress(spp.filename, spp.startTime, spp.fileSize, percentage, spp.outputStream, spp.showProgressBar) + spp.done = spp.updateProgress(spp.filename, spp.startTime, spp.fileSize, percentage, spp.outputStream, spp.showProgressBar) } } } -func percent(seenSoFar int64, size float64) float64 { +func (spp *snowflakeProgressPercentage) percent(seenSoFar int64, size float64) float64 { if float64(seenSoFar) >= size || size <= 0 { return 1.0 } return float64(seenSoFar) / size } -func updateProgress(filename string, startTime time.Time, totalSize float64, progress float64, outputStream *io.Writer, showProgressBar bool) bool { +func (spp *snowflakeProgressPercentage) updateProgress(filename string, startTime time.Time, totalSize float64, progress float64, outputStream *io.Writer, showProgressBar bool) bool { barLength := 10 totalSize /= mb status := "" diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index d4ad624b8..0cfe7f8e2 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -3,9 +3,11 @@ package gosnowflake import ( + "bytes" "context" "errors" "fmt" + "io" "net/url" "os" "path" @@ -597,3 +599,31 @@ func TestUploadWhenFilesystemReadOnlyError(t *testing.T) { t.Fatalf("should error when creating the temporary directory. Instead errored with: %v", err) } } + +func TestUnitUpdateProgess(t *testing.T) { + var b bytes.Buffer + buf := io.Writer(&b) + buf.Write([]byte("testing")) + + spp := &snowflakeProgressPercentage{ + filename: "test.txt", + fileSize: float64(1500), + outputStream: &buf, + showProgressBar: true, + done: false, + } + + spp.call(0) + if spp.done != false { + t.Fatal("should not be done.") + } + + if spp.seenSoFar != 0 { + t.Fatalf("expected seenSoFar to be 0 but was %v", spp.seenSoFar) + } + + spp.call(1516) + if spp.done != true { + t.Fatal("should be done after updating progess") + } +} diff --git a/test_util.go b/mock_util_test.go similarity index 91% rename from test_util.go rename to mock_util_test.go index f9f7a9b25..40a6b65ba 100644 --- a/test_util.go +++ b/mock_util_test.go @@ -9,6 +9,8 @@ import ( "testing" ) +/** This file contains helper functions for tests only. **/ + func resetHTTPMocks(t *testing.T) { _, err := http.Post("http://localhost:12345/reset", "text/plain", nil) if err != nil { diff --git a/put_get_test.go b/put_get_test.go index e6ec5e0db..926d6c241 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -99,9 +99,10 @@ func TestPercentage(t *testing.T) { {14, 28, 0.5}, } for _, test := range testcases { - if percent(test.seen, test.size) != test.expected { + spp := snowflakeProgressPercentage{} + if spp.percent(test.seen, test.size) != test.expected { t.Fatalf("percentage conversion failed. %v/%v, expected: %v, got: %v", - test.seen, test.size, test.expected, percent(test.seen, test.size)) + test.seen, test.size, test.expected, spp.percent(test.seen, test.size)) } } }