From 90ba68d2f5b06c2d8955ee1889af465b8f718707 Mon Sep 17 00:00:00 2001 From: Przemyslaw Motacki Date: Fri, 8 Sep 2023 06:57:10 +0200 Subject: [PATCH] SNOW-911146 Extract method of preparing connection in tests --- chunk_test.go | 417 ++++++++++---------- connection_test.go | 303 ++++++--------- converter_test.go | 179 ++++----- driver_test.go | 51 ++- file_transfer_agent_test.go | 750 ++++++++++++++++-------------------- heartbeat_test.go | 76 ++-- htap_test.go | 22 +- multistatement_test.go | 192 +++++---- put_get_with_aws_test.go | 265 ++++++------- rows_test.go | 50 +-- telemetry_test.go | 508 +++++++++++------------- 11 files changed, 1255 insertions(+), 1558 deletions(-) diff --git a/chunk_test.go b/chunk_test.go index c9a158602..b58b437a9 100644 --- a/chunk_test.go +++ b/chunk_test.go @@ -395,253 +395,226 @@ func TestWithStreamDownloader(t *testing.T) { } func TestWithArrowBatches(t *testing.T) { - ctx := WithArrowBatches(context.Background()) - numrows := 3000 // approximately 6 ArrowBatch objects - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := WithArrowBatches(context.Background()) + numrows := 3000 // approximately 6 ArrowBatch objects - pool := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer pool.AssertSize(t, 0) - ctx = WithArrowAllocator(ctx, pool) + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + ctx = WithArrowAllocator(ctx, pool) - query := fmt.Sprintf(selectRandomGenerator, numrows) - rows, err := sc.QueryContext(ctx, query, []driver.NamedValue{}) - if err != nil { - t.Error(err) - } - defer rows.Close() + query := fmt.Sprintf(selectRandomGenerator, numrows) + rows, err := sct.sc.QueryContext(ctx, query, []driver.NamedValue{}) + if err != nil { + t.Error(err) + } + defer rows.Close() - // getting result batches - batches, err := rows.(*snowflakeRows).GetArrowBatches() - if err != nil { - t.Error(err) - } - numBatches := len(batches) - maxWorkers := 10 // enough for 3000 rows - type count struct { - m sync.Mutex - recVal int - metaVal int - } - cnt := count{recVal: 0} - var wg sync.WaitGroup - chunks := make(chan int, numBatches) - - // kicking off download workers - each of which will call fetch on a different result batch - for w := 1; w <= maxWorkers; w++ { - wg.Add(1) - go func(wg *sync.WaitGroup, chunks <-chan int) { - defer wg.Done() - - for i := range chunks { - rec, err := batches[i].Fetch() - if err != nil { - t.Error(err) - } - for _, r := range *rec { + // getting result batches + batches, err := rows.(*snowflakeRows).GetArrowBatches() + if err != nil { + t.Error(err) + } + numBatches := len(batches) + maxWorkers := 10 // enough for 3000 rows + type count struct { + m sync.Mutex + recVal int + metaVal int + } + cnt := count{recVal: 0} + var wg sync.WaitGroup + chunks := make(chan int, numBatches) + + // kicking off download workers - each of which will call fetch on a different result batch + for w := 1; w <= maxWorkers; w++ { + wg.Add(1) + go func(wg *sync.WaitGroup, chunks <-chan int) { + defer wg.Done() + + for i := range chunks { + rec, err := batches[i].Fetch() + if err != nil { + t.Error(err) + } + for _, r := range *rec { + cnt.m.Lock() + cnt.recVal += int(r.NumRows()) + cnt.m.Unlock() + r.Release() + } cnt.m.Lock() - cnt.recVal += int(r.NumRows()) + cnt.metaVal += 
batches[i].rowCount cnt.m.Unlock() - r.Release() } - cnt.m.Lock() - cnt.metaVal += batches[i].rowCount - cnt.m.Unlock() - } - }(&wg, chunks) - } - for j := 0; j < numBatches; j++ { - chunks <- j - } - close(chunks) + }(&wg, chunks) + } + for j := 0; j < numBatches; j++ { + chunks <- j + } + close(chunks) - // wait for workers to finish fetching and check row counts - wg.Wait() - if cnt.recVal != numrows { - t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) - } - if cnt.metaVal != numrows { - t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) - } + // wait for workers to finish fetching and check row counts + wg.Wait() + if cnt.recVal != numrows { + t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) + } + if cnt.metaVal != numrows { + t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) + } + }) } func TestWithArrowBatchesAsync(t *testing.T) { - ctx := WithAsyncMode(context.Background()) - ctx = WithArrowBatches(ctx) - numrows := 50000 // approximately 10 ArrowBatch objects - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := WithAsyncMode(context.Background()) + ctx = WithArrowBatches(ctx) + numrows := 50000 // approximately 10 ArrowBatch objects - pool := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer pool.AssertSize(t, 0) - ctx = WithArrowAllocator(ctx, pool) + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + ctx = WithArrowAllocator(ctx, pool) - query := fmt.Sprintf(selectRandomGenerator, numrows) - rows, err := sc.QueryContext(ctx, query, []driver.NamedValue{}) - if err != nil { - t.Error(err) - } - defer rows.Close() + query := fmt.Sprintf(selectRandomGenerator, numrows) + rows, err := sct.sc.QueryContext(ctx, query, []driver.NamedValue{}) + if err != nil { + t.Error(err) + } + defer rows.Close() - // getting result batches - // this will fail if GetArrowBatches() is not a blocking call - batches, err := rows.(*snowflakeRows).GetArrowBatches() - if err != nil { - t.Error(err) - } - numBatches := len(batches) - maxWorkers := 10 - type count struct { - m sync.Mutex - recVal int - metaVal int - } - cnt := count{recVal: 0} - var wg sync.WaitGroup - chunks := make(chan int, numBatches) - - // kicking off download workers - each of which will call fetch on a different result batch - for w := 1; w <= maxWorkers; w++ { - wg.Add(1) - go func(wg *sync.WaitGroup, chunks <-chan int) { - defer wg.Done() - - for i := range chunks { - rec, err := batches[i].Fetch() - if err != nil { - t.Error(err) - } - for _, r := range *rec { + // getting result batches + // this will fail if GetArrowBatches() is not a blocking call + batches, err := rows.(*snowflakeRows).GetArrowBatches() + if err != nil { + t.Error(err) + } + numBatches := len(batches) + maxWorkers := 10 + type count struct { + m sync.Mutex + recVal int + metaVal int + } + cnt := count{recVal: 0} + var wg sync.WaitGroup + chunks := make(chan int, numBatches) + + // kicking off download workers - each of which will call fetch on a different result batch + for w := 1; w <= maxWorkers; w++ { + wg.Add(1) + go func(wg *sync.WaitGroup, chunks <-chan int) { + defer 
wg.Done() + + for i := range chunks { + rec, err := batches[i].Fetch() + if err != nil { + t.Error(err) + } + for _, r := range *rec { + cnt.m.Lock() + cnt.recVal += int(r.NumRows()) + cnt.m.Unlock() + r.Release() + } cnt.m.Lock() - cnt.recVal += int(r.NumRows()) + cnt.metaVal += batches[i].rowCount cnt.m.Unlock() - r.Release() } - cnt.m.Lock() - cnt.metaVal += batches[i].rowCount - cnt.m.Unlock() - } - }(&wg, chunks) - } - for j := 0; j < numBatches; j++ { - chunks <- j - } - close(chunks) + }(&wg, chunks) + } + for j := 0; j < numBatches; j++ { + chunks <- j + } + close(chunks) - // wait for workers to finish fetching and check row counts - wg.Wait() - if cnt.recVal != numrows { - t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) - } - if cnt.metaVal != numrows { - t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) - } + // wait for workers to finish fetching and check row counts + wg.Wait() + if cnt.recVal != numrows { + t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) + } + if cnt.metaVal != numrows { + t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) + } + }) } func TestQueryArrowStream(t *testing.T) { - ctx := context.Background() - numrows := 50000 // approximately 10 ArrowBatch objects - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - query := fmt.Sprintf(selectRandomGenerator, numrows) - loader, err := sc.QueryArrowStream(ctx, query) - if err != nil { - t.Error(err) - } - - if loader.TotalRows() != int64(numrows) { - t.Errorf("total numrows did not match expected, wanted %v, got %v", numrows, loader.TotalRows()) - } - - batches, err := loader.GetBatches() - if err != nil { - t.Error(err) - } - - numBatches := len(batches) - maxWorkers := 8 - chunks := make(chan int, numBatches) - total := int64(0) - meta := int64(0) + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() + numrows := 50000 // approximately 10 ArrowBatch objects - var wg sync.WaitGroup - wg.Add(maxWorkers) - - mem := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer mem.AssertSize(t, 0) - - for w := 0; w < maxWorkers; w++ { - go func() { - defer wg.Done() + query := fmt.Sprintf(selectRandomGenerator, numrows) + loader, err := sct.sc.QueryArrowStream(ctx, query) + if err != nil { + t.Error(err) + } - for i := range chunks { - r, err := batches[i].GetStream(ctx) - if err != nil { - t.Error(err) - continue - } - rdr, err := ipc.NewReader(r, ipc.WithAllocator(mem)) - if err != nil { - t.Errorf("Error creating IPC reader for stream %d: %s", i, err) - r.Close() - continue - } + if loader.TotalRows() != int64(numrows) { + t.Errorf("total numrows did not match expected, wanted %v, got %v", numrows, loader.TotalRows()) + } - for rdr.Next() { - rec := rdr.Record() - atomic.AddInt64(&total, rec.NumRows()) - } + batches, err := loader.GetBatches() + if err != nil { + t.Error(err) + } - if rdr.Err() != nil { - t.Error(rdr.Err()) - } - rdr.Release() - if err := r.Close(); err != nil { - t.Error(err) + numBatches := len(batches) + maxWorkers := 8 + chunks := make(chan int, numBatches) + total := int64(0) + meta := int64(0) + + var wg sync.WaitGroup + wg.Add(maxWorkers) + + mem := 
memory.NewCheckedAllocator(memory.DefaultAllocator)
+		defer mem.AssertSize(t, 0)
+
+		for w := 0; w < maxWorkers; w++ {
+			go func() {
+				defer wg.Done()
+
+				for i := range chunks {
+					r, err := batches[i].GetStream(ctx)
+					if err != nil {
+						t.Error(err)
+						continue
+					}
+					rdr, err := ipc.NewReader(r, ipc.WithAllocator(mem))
+					if err != nil {
+						t.Errorf("Error creating IPC reader for stream %d: %s", i, err)
+						r.Close()
+						continue
+					}
+
+					for rdr.Next() {
+						rec := rdr.Record()
+						atomic.AddInt64(&total, rec.NumRows())
+					}
+
+					if rdr.Err() != nil {
+						t.Error(rdr.Err())
+					}
+					rdr.Release()
+					if err := r.Close(); err != nil {
+						t.Error(err)
+					}
+					atomic.AddInt64(&meta, batches[i].NumRows())
+				}
+			}()
+		}
 
-	for j := 0; j < numBatches; j++ {
-		chunks <- j
-	}
-	close(chunks)
-	wg.Wait()
+		for j := 0; j < numBatches; j++ {
+			chunks <- j
+		}
+		close(chunks)
+		wg.Wait()
 
-	if total != int64(numrows) {
-		t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, total)
-	}
-	if meta != int64(numrows) {
-		t.Errorf("number of rows from batch metadata didn't match. expected: %v, got: %v", numrows, total)
-	}
+		if total != int64(numrows) {
+			t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, total)
+		}
+		if meta != int64(numrows) {
+			t.Errorf("number of rows from batch metadata didn't match. expected: %v, got: %v", numrows, meta)
+		}
+	})
 }
diff --git a/connection_test.go b/connection_test.go
index 2b4ce5750..f40fd1ced 100644
--- a/connection_test.go
+++ b/connection_test.go
@@ -398,70 +398,49 @@ func TestPrivateLink(t *testing.T) {
 }
 
 func TestGetQueryStatus(t *testing.T) {
-	config, err := ParseDSN(dsn)
-	if err != nil {
-		t.Error(err)
-	}
-	ctx := context.Background()
-	sc, err := buildSnowflakeConn(ctx, *config)
-	if err != nil {
-		t.Error(err)
-	}
-	if err = authenticateWithConfig(sc); err != nil {
-		t.Error(err)
-	}
+	runSnowflakeConnTest(t, func(sct *SCTest) {
+		ctx := context.Background()
 
-	if _, err = sc.Exec(`create or replace table ut_conn(c1 number, c2 string)
+		sct.mustExec(`create or replace table ut_conn(c1 number, c2 string)
 			as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 from
-			table(generator(rowcount => 100)))`, nil); err != nil {
-		t.Error(err)
-	}
+			table(generator(rowcount => 100)))`, nil)
 
-	rows, err := sc.QueryContext(ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil)
-	if err != nil {
-		t.Error(err)
-	}
-	qid := rows.(SnowflakeResult).GetQueryID()
+		rows, err := sct.sc.QueryContext(ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil)
+		if err != nil {
+			t.Error(err)
+		}
+		qid := rows.(SnowflakeResult).GetQueryID()
 
-	// use conn as type holder for SnowflakeConnection placeholder
-	var conn interface{} = sc
-	qStatus, err := conn.(SnowflakeConnection).GetQueryStatus(ctx, qid)
-	if err != nil {
-		t.Errorf("failed to get query status err = %s", err.Error())
-		return
-	}
-	if qStatus == nil {
-		t.Error("there was no query status returned")
-		return
-	}
+		// use conn as type holder for SnowflakeConnection placeholder
+		var conn interface{} = sct.sc
+		qStatus, err := conn.(SnowflakeConnection).GetQueryStatus(ctx, qid)
+		if err != nil {
+			t.Errorf("failed to get query status err = %s", err.Error())
+			return
+		}
+		if qStatus == nil {
+			t.Error("there was no query status returned")
+			return
+		}
 
-	if qStatus.ErrorCode != "" || qStatus.ScanBytes != 2048 || qStatus.ProducedRows != 10 {
-		
t.Errorf("expected no error. got: %v, scan bytes: %v, produced rows: %v", - qStatus.ErrorCode, qStatus.ScanBytes, qStatus.ProducedRows) - return - } + if qStatus.ErrorCode != "" || qStatus.ScanBytes != 2048 || qStatus.ProducedRows != 10 { + t.Errorf("expected no error. got: %v, scan bytes: %v, produced rows: %v", + qStatus.ErrorCode, qStatus.ScanBytes, qStatus.ProducedRows) + return + } + }) } func TestGetInvalidQueryStatus(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() + sct.sc.rest.RequestTimeout = 1 * time.Second - sc.rest.RequestTimeout = 1 * time.Second - - qStatus, err := sc.checkQueryStatus(ctx, "1234") - if err == nil || qStatus != nil { - t.Error("expected an error") - } + qStatus, err := sct.sc.checkQueryStatus(ctx, "1234") + if err == nil || qStatus != nil { + t.Error("expected an error") + } + }) } func TestExecWithServerSideError(t *testing.T) { @@ -575,156 +554,112 @@ func executeQueryAndConfirmMessage(db *sql.DB, query string, expectedErrorTable } func TestQueryArrowStreamError(t *testing.T) { - ctx := context.Background() - numrows := 50000 // approximately 10 ArrowBatch objects - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - query := fmt.Sprintf(selectRandomGenerator, numrows) - sr := &snowflakeRestful{ - FuncPostQuery: postQueryTest, - TokenAccessor: getSimpleTokenAccessor(), - RequestTimeout: 10, - } - sc.rest = sr - _, err = sc.QueryArrowStream(ctx, query) - if err == nil { - t.Error("should have raised an error") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() + numrows := 50000 // approximately 10 ArrowBatch objects + query := fmt.Sprintf(selectRandomGenerator, numrows) + sr := &snowflakeRestful{ + FuncPostQuery: postQueryTest, + TokenAccessor: getSimpleTokenAccessor(), + RequestTimeout: 10, + } + sc := sct.sc + sc.rest = sr + _, err := sc.QueryArrowStream(ctx, query) + if err == nil { + t.Error("should have raised an error") + } - sc.rest.FuncPostQuery = postQueryFail - _, err = sc.QueryArrowStream(ctx, query) - if err == nil { - t.Error("should have raised an error") - } - _, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } + sc.rest.FuncPostQuery = postQueryFail + _, err = sc.QueryArrowStream(ctx, query) + if err == nil { + t.Error("should have raised an error") + } + _, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. 
err: %v", err) + } + }) } func TestExecContextError(t *testing.T) { - ctx := context.Background() - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() - sr := &snowflakeRestful{ - FuncPostQuery: postQueryTest, - TokenAccessor: getSimpleTokenAccessor(), - RequestTimeout: 10, - } + sr := &snowflakeRestful{ + FuncPostQuery: postQueryTest, + TokenAccessor: getSimpleTokenAccessor(), + RequestTimeout: 10, + } - sc.rest = sr + sc := sct.sc + sc.rest = sr - _, err = sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{}) - if err == nil { - t.Fatalf("should have raised an error") - } + _, err := sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{}) + if err == nil { + t.Fatalf("should have raised an error") + } - sc.rest.FuncPostQuery = postQueryFail - _, err = sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{}) - if err == nil { - t.Fatalf("should have raised an error") - } - _, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } + sc.rest.FuncPostQuery = postQueryFail + _, err = sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{}) + if err == nil { + t.Fatalf("should have raised an error") + } + _, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + }) } func TestQueryContextError(t *testing.T) { - ctx := context.Background() - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() - sr := &snowflakeRestful{ - FuncPostQuery: postQueryTest, - TokenAccessor: getSimpleTokenAccessor(), - RequestTimeout: 10, - } + sr := &snowflakeRestful{ + FuncPostQuery: postQueryTest, + TokenAccessor: getSimpleTokenAccessor(), + RequestTimeout: 10, + } - sc.rest = sr + sc := sct.sc + sc.rest = sr - _, err = sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{}) - if err == nil { - t.Fatalf("should have raised an error") - } + _, err := sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{}) + if err == nil { + t.Fatalf("should have raised an error") + } - sc.rest.FuncPostQuery = postQueryFail - _, err = sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{}) - if err == nil { - t.Fatalf("should have raised an error") - } - _, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } + sc.rest.FuncPostQuery = postQueryFail + _, err = sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{}) + if err == nil { + t.Fatalf("should have raised an error") + } + _, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + }) } func TestPrepareQuery(t *testing.T) { - ctx := context.Background() - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - _, err = sc.Prepare("SELECT 1") + runSnowflakeConnTest(t, func(sct *SCTest) { + _, err := sct.sc.Prepare("SELECT 1") - if err != nil { - t.Fatalf("failed to prepare query. err: %v", err) - } - sc.Close() + if err != nil { + t.Fatalf("failed to prepare query. 
err: %v", err) + } + sct.sc.Close() + }) } func TestBeginCreatesTransaction(t *testing.T) { - ctx := context.Background() - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - tx, _ := sc.Begin() - if tx == nil { - t.Fatal("should have created a transaction with connection") - } - sc.Close() + runSnowflakeConnTest(t, func(sct *SCTest) { + tx, _ := sct.sc.Begin() + if tx == nil { + t.Fatal("should have created a transaction with connection") + } + sct.sc.Close() + }) } diff --git a/converter_test.go b/converter_test.go index e48b3e503..571edde04 100644 --- a/converter_test.go +++ b/converter_test.go @@ -1242,125 +1242,96 @@ func TestArrowToRecord(t *testing.T) { } func TestTimestampLTZLocation(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - src := "1549491451.123456789" - var dest driver.Value - loc, _ := time.LoadLocation(PSTLocation) - if err = stringToValue(&dest, execResponseRowType{Type: "timestamp_ltz"}, &src, loc); err != nil { - t.Errorf("unexpected error: %v", err) - } - ts, ok := dest.(time.Time) - if !ok { - t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) - } - if ts.Location() != loc { - t.Errorf("expected location to be %v, got '%v'", loc, ts.Location()) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + src := "1549491451.123456789" + var dest driver.Value + loc, _ := time.LoadLocation(PSTLocation) + if err := stringToValue(&dest, execResponseRowType{Type: "timestamp_ltz"}, &src, loc); err != nil { + t.Errorf("unexpected error: %v", err) + } + ts, ok := dest.(time.Time) + if !ok { + t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) + } + if ts.Location() != loc { + t.Errorf("expected location to be %v, got '%v'", loc, ts.Location()) + } - if err = stringToValue(&dest, execResponseRowType{Type: "timestamp_ltz"}, &src, nil); err != nil { - t.Errorf("unexpected error: %v", err) - } - ts, ok = dest.(time.Time) - if !ok { - t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) - } - if ts.Location() != time.Local { - t.Errorf("expected location to be local, got '%v'", ts.Location()) - } + if err := stringToValue(&dest, execResponseRowType{Type: "timestamp_ltz"}, &src, nil); err != nil { + t.Errorf("unexpected error: %v", err) + } + ts, ok = dest.(time.Time) + if !ok { + t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) + } + if ts.Location() != time.Local { + t.Errorf("expected location to be local, got '%v'", ts.Location()) + } + }) } func TestSmallTimestampBinding(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - timeValue, err := time.Parse("2006-01-02 15:04:05", "1600-10-10 10:10:10") - if err != nil { - t.Fatalf("failed to parse time: %v", err) - } - parameters := []driver.NamedValue{ - {Ordinal: 1, Value: DataTypeTimestampNtz}, - {Ordinal: 2, Value: timeValue}, - } - - rows, err := sc.QueryContext(ctx, "SELECT ?", parameters) - if err != nil { - t.Fatalf("failed to run query: %v", err) - } - 
defer rows.Close() + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() + timeValue, err := time.Parse("2006-01-02 15:04:05", "1600-10-10 10:10:10") + if err != nil { + t.Fatalf("failed to parse time: %v", err) + } + parameters := []driver.NamedValue{ + {Ordinal: 1, Value: DataTypeTimestampNtz}, + {Ordinal: 2, Value: timeValue}, + } - scanValues := make([]driver.Value, 1) - for { - if err := rows.Next(scanValues); err == io.EOF { - break - } else if err != nil { + rows, err := sct.sc.QueryContext(ctx, "SELECT ?", parameters) + if err != nil { t.Fatalf("failed to run query: %v", err) } - if scanValues[0] != timeValue { - t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) + defer rows.Close() + + scanValues := make([]driver.Value, 1) + for { + if err := rows.Next(scanValues); err == io.EOF { + break + } else if err != nil { + t.Fatalf("failed to run query: %v", err) + } + if scanValues[0] != timeValue { + t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) + } } - } + }) } func TestLargeTimestampBinding(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - timeValue, err := time.Parse("2006-01-02 15:04:05", "9000-10-10 10:10:10") - if err != nil { - t.Fatalf("failed to parse time: %v", err) - } - parameters := []driver.NamedValue{ - {Ordinal: 1, Value: DataTypeTimestampNtz}, - {Ordinal: 2, Value: timeValue}, - } - - rows, err := sc.QueryContext(ctx, "SELECT ?", parameters) - if err != nil { - t.Fatalf("failed to run query: %v", err) - } - defer rows.Close() + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() + timeValue, err := time.Parse("2006-01-02 15:04:05", "9000-10-10 10:10:10") + if err != nil { + t.Fatalf("failed to parse time: %v", err) + } + parameters := []driver.NamedValue{ + {Ordinal: 1, Value: DataTypeTimestampNtz}, + {Ordinal: 2, Value: timeValue}, + } - scanValues := make([]driver.Value, 1) - for { - if err := rows.Next(scanValues); err == io.EOF { - break - } else if err != nil { + rows, err := sct.sc.QueryContext(ctx, "SELECT ?", parameters) + if err != nil { t.Fatalf("failed to run query: %v", err) } - if scanValues[0] != timeValue { - t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) + defer rows.Close() + + scanValues := make([]driver.Value, 1) + for { + if err := rows.Next(scanValues); err == io.EOF { + break + } else if err != nil { + t.Fatalf("failed to run query: %v", err) + } + if scanValues[0] != timeValue { + t.Fatalf("unexpected result. 
expected: %v, got: %v", timeValue, scanValues[0])
+			}
+		}
+	})
 }
 
 func TestTimeTypeValueToString(t *testing.T) {
diff --git a/driver_test.go b/driver_test.go
index 653c62777..c84068030 100644
--- a/driver_test.go
+++ b/driver_test.go
@@ -6,6 +6,7 @@ import (
 	"context"
 	"crypto/rsa"
 	"database/sql"
+	"database/sql/driver"
 	"flag"
 	"fmt"
 	"net/http"
@@ -320,6 +321,34 @@ func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) {
 	return stmt
 }
 
+type SCTest struct {
+	*testing.T
+	sc *snowflakeConn
+}
+
+func (sct *SCTest) fail(method, query string, err error) {
+	if len(query) > 300 {
+		query = "[query too large to print]"
+	}
+	sct.Fatalf("error on %s [%s]: %s", method, query, err.Error())
+}
+
+func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result {
+	result, err := sct.sc.Exec(query, args)
+	if err != nil {
+		sct.fail("exec", query, err)
+	}
+	return result
+}
+
+func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows {
+	rows, err := sct.sc.Query(query, args)
+	if err != nil {
+		sct.fail("query", query, err)
+	}
+	return rows
+}
+
 func runDBTest(t *testing.T, test func(dbt *DBTest)) {
 	conn := openConn(t)
 	defer conn.Close()
@@ -328,7 +357,7 @@ func runDBTest(t *testing.T, test func(dbt *DBTest)) {
 	test(dbt)
 }
 
-func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) {
+func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) {
 	config, err := ParseDSN(dsn)
 	if err != nil {
 		t.Error(err)
@@ -342,7 +371,8 @@ func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) {
 		t.Fatal(err)
 	}
 
-	test(sc)
+	sct := &SCTest{t, sc}
+	test(sct)
 }
 
 func runningOnAWS() bool {
diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go
index 633bb5019..a53986576 100644
--- a/file_transfer_agent_test.go
+++ b/file_transfer_agent_test.go
@@ -28,33 +28,24 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) {
 	if runningOnGithubAction() {
 		t.Skip("Should be run against an account in AWS EU North1 region.")
 	}
-	config, err := ParseDSN(dsn)
-	if err != nil {
-		t.Error(err)
-	}
-	sc, err := buildSnowflakeConn(context.Background(), *config)
-	if err != nil {
-		t.Fatal(err)
-	}
-	if err = authenticateWithConfig(sc); err != nil {
-		t.Fatal(err)
-	}
-	sfa := &snowflakeFileTransferAgent{
-		sc:          sc,
-		commandType: uploadCommand,
-		srcFiles:    make([]string, 0),
-		data: &execResponseData{
-			SrcLocations: make([]string, 0),
-		},
-	}
-	if err = sfa.transferAccelerateConfig(); err != nil {
-		var ae smithy.APIError
-		if errors.As(err, &ae) {
-			if ae.ErrorCode() == "MethodNotAllowed" {
-				t.Fatalf("should have ignored 405 error: %v", err)
+	runSnowflakeConnTest(t, func(sct *SCTest) {
+		sfa := &snowflakeFileTransferAgent{
+			sc:          sct.sc,
+			commandType: uploadCommand,
+			srcFiles:    make([]string, 0),
+			data: &execResponseData{
+				SrcLocations: make([]string, 0),
+			},
+		}
+		if err := sfa.transferAccelerateConfig(); err != nil {
+			var ae smithy.APIError
+			if errors.As(err, &ae) {
+				if ae.ErrorCode() == "MethodNotAllowed" {
+					t.Fatalf("should have ignored 405 error: %v", err)
+				}
 			}
 		}
-	}
+	})
 }
 
 func TestUnitDownloadWithInvalidLocalPath(t *testing.T) {
@@ -88,455 +79,374 @@ func 
TestUnitDownloadWithInvalidLocalPath(t *testing.T) { }) } func TestUnitGetLocalFilePathFromCommand(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: make([]string, 0), - }, - } - testcases := []tcFilePath{ - {"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"}, - {"PUT 'file:///tmp/my_data_file.txt' @~ overwrite=true", "/tmp/my_data_file.txt"}, - {"PUT file:///tmp/sub_dir/my_data_file.txt\n @~ overwrite=true", "/tmp/sub_dir/my_data_file.txt"}, - {"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"}, - {"", ""}, - {"PUT 'file2:///tmp/my_data_file.txt' @~ overwrite=true", ""}, - } - for _, test := range testcases { - t.Run(test.command, func(t *testing.T) { - path := sfa.getLocalFilePathFromCommand(test.command) - if path != test.path { - t.Fatalf("unexpected file path. expected: %v, but got: %v", test.path, path) - } - }) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: make([]string, 0), + }, + } + testcases := []tcFilePath{ + {"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"}, + {"PUT 'file:///tmp/my_data_file.txt' @~ overwrite=true", "/tmp/my_data_file.txt"}, + {"PUT file:///tmp/sub_dir/my_data_file.txt\n @~ overwrite=true", "/tmp/sub_dir/my_data_file.txt"}, + {"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"}, + {"", ""}, + {"PUT 'file2:///tmp/my_data_file.txt' @~ overwrite=true", ""}, + } + for _, test := range testcases { + t.Run(test.command, func(t *testing.T) { + path := sfa.getLocalFilePathFromCommand(test.command) + if path != test.path { + t.Fatalf("unexpected file path. 
expected: %v, but got: %v", test.path, path) + } + }) + } + }) } func TestUnitProcessFileCompressionType(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - } - testcases := []struct { - srcCompression string - }{ - {"none"}, - {"auto_detect"}, - {"gzip"}, - } + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + } + testcases := []struct { + srcCompression string + }{ + {"none"}, + {"auto_detect"}, + {"gzip"}, + } - for _, test := range testcases { - t.Run(test.srcCompression, func(t *testing.T) { - sfa.srcCompression = test.srcCompression - err = sfa.processFileCompressionType() - if err != nil { - t.Fatalf("failed to process file compression") - } - }) - } + for _, test := range testcases { + t.Run(test.srcCompression, func(t *testing.T) { + sfa.srcCompression = test.srcCompression + err := sfa.processFileCompressionType() + if err != nil { + t.Fatalf("failed to process file compression") + } + }) + } - // test invalid compression type error - sfa.srcCompression = "gz" - data := &execResponseData{ - SQLState: "S00087", - QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", - } - sfa.data = data - err = sfa.processFileCompressionType() - if err == nil { - t.Fatal("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCompressionNotSupported { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCompressionNotSupported, driverErr.Number) - } + // test invalid compression type error + sfa.srcCompression = "gz" + data := &execResponseData{ + SQLState: "S00087", + QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", + } + sfa.data = data + err := sfa.processFileCompressionType() + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCompressionNotSupported { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCompressionNotSupported, driverErr.Number) + } + }) } func TestParseCommandWithInvalidStageLocation(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: make([]string, 0), - }, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: make([]string, 0), + }, + } - err = sfa.parseCommand() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrInvalidStageLocation { - t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrInvalidStageLocation, driverErr.Number) - } + err := sfa.parseCommand() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrInvalidStageLocation { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageLocation, driverErr.Number) + } + }) } func TestParseCommandEncryptionMaterialMismatchError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { - mockEncMaterial1 := snowflakeFileEncryption{ - QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", - QueryID: "01abc874-0406-1bf0-0000-53b10668e056", - SMKID: 92019681909886, - } + mockEncMaterial1 := snowflakeFileEncryption{ + QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", + QueryID: "01abc874-0406-1bf0-0000-53b10668e056", + SMKID: 92019681909886, + } - mockEncMaterial2 := snowflakeFileEncryption{ - QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", - QueryID: "01abc874-0406-1bf0-0000-53b10668e056", - SMKID: 92019681909886, - } + mockEncMaterial2 := snowflakeFileEncryption{ + QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", + QueryID: "01abc874-0406-1bf0-0000-53b10668e056", + SMKID: 92019681909886, + } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: []string{"/tmp/uploads"}, - EncryptionMaterial: encryptionWrapper{ - snowflakeFileEncryption: mockEncMaterial1, - EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1, mockEncMaterial2}, + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: []string{"/tmp/uploads"}, + EncryptionMaterial: encryptionWrapper{ + snowflakeFileEncryption: mockEncMaterial1, + EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1, mockEncMaterial2}, + }, }, - }, - } + } - err = sfa.parseCommand() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrInternalNotMatchEncryptMaterial { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInternalNotMatchEncryptMaterial, driverErr.Number) - } + err := sfa.parseCommand() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrInternalNotMatchEncryptMaterial { + t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrInternalNotMatchEncryptMaterial, driverErr.Number) + } + }) } func TestParseCommandInvalidStorageClientException(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { - tmpDir, err := os.MkdirTemp("", "abc") - if err != nil { - t.Error(err) - } - mockEncMaterial1 := snowflakeFileEncryption{ - QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", - QueryID: "01abc874-0406-1bf0-0000-53b10668e056", - SMKID: 92019681909886, - } + tmpDir, err := os.MkdirTemp("", "abc") + if err != nil { + t.Error(err) + } + mockEncMaterial1 := snowflakeFileEncryption{ + QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", + QueryID: "01abc874-0406-1bf0-0000-53b10668e056", + SMKID: 92019681909886, + } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: []string{"/tmp/uploads"}, - LocalLocation: tmpDir, - EncryptionMaterial: encryptionWrapper{ - snowflakeFileEncryption: mockEncMaterial1, - EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1}, + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: []string{"/tmp/uploads"}, + LocalLocation: tmpDir, + EncryptionMaterial: encryptionWrapper{ + snowflakeFileEncryption: mockEncMaterial1, + EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1}, + }, }, - }, - options: &SnowflakeFileTransferOptions{ - DisablePutOverwrite: false, - }, - } + options: &SnowflakeFileTransferOptions{ + DisablePutOverwrite: false, + }, + } - err = sfa.parseCommand() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrInvalidStageFs { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageFs, driverErr.Number) - } + err = sfa.parseCommand() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrInvalidStageFs { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageFs, driverErr.Number) + } + }) } func TestInitFileMetadataError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: []string{"fileDoesNotExist.txt"}, - data: &execResponseData{ - SQLState: "123456", - QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", - }, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: []string{"fileDoesNotExist.txt"}, + data: &execResponseData{ + SQLState: "123456", + QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", + }, + } - err = sfa.initFileMetadata() - if err == nil { - t.Fatal("should have raised an error") - } + err := sfa.initFileMetadata() + if err == nil { + t.Fatal("should have raised an error") + } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrFileNotExists { - t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrFileNotExists, driverErr.Number) - } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrFileNotExists { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFileNotExists, driverErr.Number) + } - tmpDir, err := os.MkdirTemp("", "data") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) - sfa.srcFiles = []string{tmpDir} + tmpDir, err := os.MkdirTemp("", "data") + if err != nil { + t.Error(err) + } + defer os.RemoveAll(tmpDir) + sfa.srcFiles = []string{tmpDir} - err = sfa.initFileMetadata() - if err == nil { - t.Fatal("should have raised an error") - } + err = sfa.initFileMetadata() + if err == nil { + t.Fatal("should have raised an error") + } - driverErr, ok = err.(*SnowflakeError) - if !ok || driverErr.Number != ErrFileNotExists { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFileNotExists, driverErr.Number) - } + driverErr, ok = err.(*SnowflakeError) + if !ok || driverErr.Number != ErrFileNotExists { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFileNotExists, driverErr.Number) + } + }) } func TestUpdateMetadataWithPresignedUrl(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - info := execResponseStageInfo{ - Location: "gcs-blob/storage/users/456/", - LocationType: "GCS", - } + runSnowflakeConnTest(t, func(sct *SCTest) { + info := execResponseStageInfo{ + Location: "gcs-blob/storage/users/456/", + LocationType: "GCS", + } - dir, err := os.Getwd() - if err != nil { - t.Error(err) - } + dir, err := os.Getwd() + if err != nil { + t.Error(err) + } - testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" - - presignedURLMock := func(_ context.Context, _ *snowflakeRestful, - _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, - requestID UUID, _ *Config) (*execResponse, error) { - // ensure the same requestID from context is used - if len(requestID) == 0 { - t.Fatal("requestID is empty") - } - dd := &execResponseData{ - QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", - Command: string(uploadCommand), - StageInfo: execResponseStageInfo{ - LocationType: "GCS", - Location: "gcspuscentral1-4506459564-stage/users/456", - Path: "users/456", - Region: "US_CENTRAL1", - PresignedURL: testURL, - }, + testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" + + presignedURLMock := func(_ context.Context, _ *snowflakeRestful, + _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, + requestID UUID, _ *Config) (*execResponse, error) { + // ensure the same requestID from context is used + if len(requestID) == 0 { + t.Fatal("requestID is empty") + } + dd := &execResponseData{ + QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", + Command: string(uploadCommand), + StageInfo: execResponseStageInfo{ + LocationType: "GCS", + Location: "gcspuscentral1-4506459564-stage/users/456", + Path: "users/456", + Region: "US_CENTRAL1", + PresignedURL: testURL, + }, + } + return &execResponse{ + Data: *dd, + Message: "", + Code: "0", + Success: true, + }, nil } - return &execResponse{ - Data: *dd, - Message: "", - Code: "0", - Success: true, - }, nil - } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) - if err != nil { - t.Error(err) - } - uploadMeta := fileMetadata{ - name: "data1.txt.gz", - 
stageLocationType: "GCS", - noSleepingTime: true, - client: gcsCli, - sha256Digest: "123456789abcdef", - stageInfo: &info, - dstFileName: "data1.txt.gz", - srcFileName: path.Join(dir, "/test_data/data1.txt"), - overwrite: true, - options: &SnowflakeFileTransferOptions{ - MultiPartThreshold: dataSizeThreshold, - }, - } + gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + if err != nil { + t.Error(err) + } + uploadMeta := fileMetadata{ + name: "data1.txt.gz", + stageLocationType: "GCS", + noSleepingTime: true, + client: gcsCli, + sha256Digest: "123456789abcdef", + stageInfo: &info, + dstFileName: "data1.txt.gz", + srcFileName: path.Join(dir, "/test_data/data1.txt"), + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + } - sc.rest.FuncPostQuery = presignedURLMock - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - command: "put file:///tmp/test_data/data1.txt @~", - stageLocationType: gcsClient, - fileMetadata: []*fileMetadata{&uploadMeta}, - } + sct.sc.rest.FuncPostQuery = presignedURLMock + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + command: "put file:///tmp/test_data/data1.txt @~", + stageLocationType: gcsClient, + fileMetadata: []*fileMetadata{&uploadMeta}, + } - err = sfa.updateFileMetadataWithPresignedURL() - if err != nil { - t.Error(err) - } - if testURL != sfa.fileMetadata[0].presignedURL.String() { - t.Fatalf("failed to update metadata with presigned url. expected: %v. got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) - } + err = sfa.updateFileMetadataWithPresignedURL() + if err != nil { + t.Error(err) + } + if testURL != sfa.fileMetadata[0].presignedURL.String() { + t.Fatalf("failed to update metadata with presigned url. expected: %v. 
got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) + } + }) } func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - info := execResponseStageInfo{ - Location: "gcs-blob/storage/users/456/", - LocationType: "GCS", - } + runSnowflakeConnTest(t, func(sct *SCTest) { + info := execResponseStageInfo{ + Location: "gcs-blob/storage/users/456/", + LocationType: "GCS", + } - dir, err := os.Getwd() - if err != nil { - t.Error(err) - } + dir, err := os.Getwd() + if err != nil { + t.Error(err) + } - testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" + testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) - if err != nil { - t.Error(err) - } - downloadMeta := fileMetadata{ - name: "data1.txt.gz", - stageLocationType: "GCS", - noSleepingTime: true, - client: gcsCli, - stageInfo: &info, - dstFileName: "data1.txt.gz", - overwrite: true, - srcFileName: "data1.txt.gz", - localLocation: dir, - } + gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + if err != nil { + t.Error(err) + } + downloadMeta := fileMetadata{ + name: "data1.txt.gz", + stageLocationType: "GCS", + noSleepingTime: true, + client: gcsCli, + stageInfo: &info, + dstFileName: "data1.txt.gz", + overwrite: true, + srcFileName: "data1.txt.gz", + localLocation: dir, + } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: downloadCommand, - command: "get @~/data1.txt.gz file:///tmp/testData", - stageLocationType: gcsClient, - fileMetadata: []*fileMetadata{&downloadMeta}, - presignedURLs: []string{testURL}, - } + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: downloadCommand, + command: "get @~/data1.txt.gz file:///tmp/testData", + stageLocationType: gcsClient, + fileMetadata: []*fileMetadata{&downloadMeta}, + presignedURLs: []string{testURL}, + } - err = sfa.updateFileMetadataWithPresignedURL() - if err != nil { - t.Error(err) - } - if testURL != sfa.fileMetadata[0].presignedURL.String() { - t.Fatalf("failed to update metadata with presigned url. expected: %v. got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) - } + err = sfa.updateFileMetadataWithPresignedURL() + if err != nil { + t.Error(err) + } + if testURL != sfa.fileMetadata[0].presignedURL.String() { + t.Fatalf("failed to update metadata with presigned url. expected: %v. 
got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) + } + }) } func TestUpdateMetadataWithPresignedUrlError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - sfa := &snowflakeFileTransferAgent{ - sc: sc, - command: "get @~/data1.txt.gz file:///tmp/testData", - stageLocationType: gcsClient, - data: &execResponseData{ - SQLState: "123456", - QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", - }, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + command: "get @~/data1.txt.gz file:///tmp/testData", + stageLocationType: gcsClient, + data: &execResponseData{ + SQLState: "123456", + QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", + }, + } - err = sfa.updateFileMetadataWithPresignedURL() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrCommandNotRecognized { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCommandNotRecognized, driverErr.Number) - } + err := sfa.updateFileMetadataWithPresignedURL() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrCommandNotRecognized { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCommandNotRecognized, driverErr.Number) + } + }) } func TestUploadWhenFilesystemReadOnlyError(t *testing.T) { diff --git a/heartbeat_test.go b/heartbeat_test.go index 839d50986..11733745d 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -3,54 +3,44 @@ package gosnowflake import ( - "context" "testing" ) func TestUnitPostHeartbeat(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { - // send heartbeat call and renew expired session - sr := &snowflakeRestful{ - FuncPost: postTestRenew, - FuncRenewSession: renewSessionTest, - TokenAccessor: getSimpleTokenAccessor(), - RequestTimeout: 0, - } - heartbeat := &heartbeat{ - restful: sr, - } - err = heartbeat.heartbeatMain() - if err != nil { - t.Fatalf("failed to heartbeat and renew session. err: %v", err) - } + // send heartbeat call and renew expired session + sr := &snowflakeRestful{ + FuncPost: postTestRenew, + FuncRenewSession: renewSessionTest, + TokenAccessor: getSimpleTokenAccessor(), + RequestTimeout: 0, + } + heartbeat := &heartbeat{ + restful: sr, + } + err := heartbeat.heartbeatMain() + if err != nil { + t.Fatalf("failed to heartbeat and renew session. err: %v", err) + } - heartbeat.restful.FuncPost = postTestSuccessButInvalidJSON - err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } + heartbeat.restful.FuncPost = postTestSuccessButInvalidJSON + err = heartbeat.heartbeatMain() + if err == nil { + t.Fatal("should have failed") + } - heartbeat.restful.FuncPost = postTestAppForbiddenError - err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrFailedToHeartbeat { - t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrFailedToHeartbeat, driverErr.Number) - } + heartbeat.restful.FuncPost = postTestAppForbiddenError + err = heartbeat.heartbeatMain() + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrFailedToHeartbeat { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToHeartbeat, driverErr.Number) + } + }) } diff --git a/htap_test.go b/htap_test.go index e2e71eb20..aac102994 100644 --- a/htap_test.go +++ b/htap_test.go @@ -257,14 +257,12 @@ func TestAddingQcesWithDifferentId(t *testing.T) { } func TestAddingQueryContextCacheEntry(t *testing.T) { - runSnowflakeConnTest(t, func(sc *snowflakeConn) { + runSnowflakeConnTest(t, func(sct *SCTest) { t.Run("First query (may be on empty cache)", func(t *testing.T) { - entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries)) - copy(entriesBefore, sc.queryContextCache.entries) - if _, err := sc.Query("SELECT 1", nil); err != nil { - t.Fatalf("cannot query. %v", err) - } - entriesAfter := sc.queryContextCache.entries + entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries)) + copy(entriesBefore, sct.sc.queryContextCache.entries) + sct.mustQuery("SELECT 1", nil) + entriesAfter := sct.sc.queryContextCache.entries if !containsNewEntries(entriesAfter, entriesBefore) { t.Error("no new entries added to the query context cache") @@ -272,15 +270,13 @@ func TestAddingQueryContextCacheEntry(t *testing.T) { }) t.Run("Second query (cache should not be empty)", func(t *testing.T) { - entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries)) - copy(entriesBefore, sc.queryContextCache.entries) + entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries)) + copy(entriesBefore, sct.sc.queryContextCache.entries) if len(entriesBefore) == 0 { t.Fatalf("cache should not be empty after first query") } - if _, err := sc.Query("SELECT 2", nil); err != nil { - t.Fatalf("cannot query. %v", err) - } - entriesAfter := sc.queryContextCache.entries + sct.mustQuery("SELECT 2", nil) + entriesAfter := sct.sc.queryContextCache.entries if !containsNewEntries(entriesAfter, entriesBefore) { t.Error("no new entries added to the query context cache") diff --git a/multistatement_test.go b/multistatement_test.go index c7a5c5cae..c9affffbb 100644 --- a/multistatement_test.go +++ b/multistatement_test.go @@ -481,115 +481,97 @@ func funcGetQueryRespError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ } func TestUnitHandleMultiExec(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - data := &execResponseData{ - ResultIDs: "", - ResultTypes: "", - } - _, err = sc.handleMultiExec(context.Background(), *data) - if err == nil { - t.Fatalf("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrNoResultIDs { - t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + data := &execResponseData{ + ResultIDs: "", + ResultTypes: "", + } + _, err := sct.sc.handleMultiExec(context.Background(), *data) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrNoResultIDs { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) + } - sr := &snowflakeRestful{ - FuncGet: funcGetQueryRespFail, - TokenAccessor: getSimpleTokenAccessor(), - } - data = &execResponseData{ - ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", - ResultTypes: "12544,12544", - } - sc.rest = sr - _, err = sc.handleMultiExec(context.Background(), *data) - if err == nil { - t.Fatalf("should have failed") - } + sr := &snowflakeRestful{ + FuncGet: funcGetQueryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + } + data = &execResponseData{ + ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", + ResultTypes: "12544,12544", + } + sct.sc.rest = sr + _, err = sct.sc.handleMultiExec(context.Background(), *data) + if err == nil { + t.Fatalf("should have failed") + } - sc.rest.FuncGet = funcGetQueryRespError - data.SQLState = "01112" - _, err = sc.handleMultiExec(context.Background(), *data) - if err == nil { - t.Fatalf("should have failed") - } - driverErr, ok = err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrFailedToPostQuery { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) - } + sct.sc.rest.FuncGet = funcGetQueryRespError + data.SQLState = "01112" + _, err = sct.sc.handleMultiExec(context.Background(), *data) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok = err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrFailedToPostQuery { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) + } + }) } func TestUnitHandleMultiQuery(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - data := &execResponseData{ - ResultIDs: "", - ResultTypes: "", - } - rows := new(snowflakeRows) - err = sc.handleMultiQuery(context.Background(), *data, rows) - if err == nil { - t.Fatalf("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrNoResultIDs { - t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) - } - sr := &snowflakeRestful{ - FuncGet: funcGetQueryRespFail, - TokenAccessor: getSimpleTokenAccessor(), - } - data = &execResponseData{ - ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", - ResultTypes: "12544,12544", - } - sc.rest = sr - err = sc.handleMultiQuery(context.Background(), *data, rows) - if err == nil { - t.Fatalf("should have failed") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + data := &execResponseData{ + ResultIDs: "", + ResultTypes: "", + } + rows := new(snowflakeRows) + err := sct.sc.handleMultiQuery(context.Background(), *data, rows) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrNoResultIDs { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) + } + sr := &snowflakeRestful{ + FuncGet: funcGetQueryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + } + data = &execResponseData{ + ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", + ResultTypes: "12544,12544", + } + sct.sc.rest = sr + err = sct.sc.handleMultiQuery(context.Background(), *data, rows) + if err == nil { + t.Fatalf("should have failed") + } - sc.rest.FuncGet = funcGetQueryRespError - data.SQLState = "01112" - err = sc.handleMultiQuery(context.Background(), *data, rows) - if err == nil { - t.Fatalf("should have failed") - } - driverErr, ok = err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrFailedToPostQuery { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) - } + sct.sc.rest.FuncGet = funcGetQueryRespError + data.SQLState = "01112" + err = sct.sc.handleMultiQuery(context.Background(), *data, rows) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok = err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrFailedToPostQuery { + t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) + } + }) } diff --git a/put_get_with_aws_test.go b/put_get_with_aws_test.go index b02fd27a8..8fed591f1 100644 --- a/put_get_with_aws_test.go +++ b/put_get_with_aws_test.go @@ -87,115 +87,103 @@ func TestLoadS3(t *testing.T) { } func TestPutWithInvalidToken(t *testing.T) { - if !runningOnAWS() { - t.Skip("skipping non aws environment") - } - tmpDir, err := os.MkdirTemp("", "aws_put") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) - fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") - originalContents := "123,test1\n456,test2\n" - - var b bytes.Buffer - gzw := gzip.NewWriter(&b) - gzw.Write([]byte(originalContents)) - gzw.Close() - if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { - t.Fatal("could not write to gzip file") - } - - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - tableName := randomString(5) - if _, err = sc.Exec("create or replace table "+tableName+ - " (a int, b string)", nil); err != nil { - t.Fatal(err) - } - defer sc.Exec("drop table "+tableName, nil) + runSnowflakeConnTest(t, func(sct *SCTest) { + if !runningOnAWS() { + t.Skip("skipping non aws environment") + } + tmpDir, err := os.MkdirTemp("", "aws_put") + if err != nil { + t.Error(err) + } + defer os.RemoveAll(tmpDir) + fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") + originalContents := "123,test1\n456,test2\n" + + var b bytes.Buffer + gzw := gzip.NewWriter(&b) + gzw.Write([]byte(originalContents)) + gzw.Close() + if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { + t.Fatal("could not write to gzip file") + } - jsonBody, err := json.Marshal(execRequest{ - SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), - }) - if err != nil { - t.Error(err) - } - headers := getHeaders() - headers[httpHeaderAccept] = headerContentTypeApplicationJSON - data, err := sc.rest.FuncPostQuery( - sc.ctx, sc.rest, &url.Values{}, headers, jsonBody, - sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sc.ctx), sc.cfg) - if err != nil { - t.Fatal(err) - } + tableName := randomString(5) + sct.mustExec("create or replace table "+tableName+" (a int, b string)", nil) + defer sct.mustExec("drop table "+tableName, nil) - s3Util := new(snowflakeS3Client) - s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) - if err != nil { - t.Error(err) - } - client := s3Cli.(*s3.Client) + jsonBody, err := json.Marshal(execRequest{ + SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), + }) + if err != nil { + t.Error(err) + } + headers := getHeaders() + headers[httpHeaderAccept] = headerContentTypeApplicationJSON + sc := sct.sc + data, err := sc.rest.FuncPostQuery( + sc.ctx, sc.rest, &url.Values{}, headers, jsonBody, + sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sc.ctx), sc.cfg) + if err != nil { + t.Fatal(err) + } - s3Loc, err := s3Util.extractBucketNameAndPath(data.Data.StageInfo.Location) - if err != nil { - t.Error(err) - } - s3Path := s3Loc.s3Path + baseName(fname) + ".gz" + s3Util := new(snowflakeS3Client) + s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) + if err != nil { + t.Error(err) + } + client := s3Cli.(*s3.Client) - f, err := os.Open(fname) - if err != nil { - t.Error(err) - } - defer f.Close() - uploader := manager.NewUploader(client) 
- if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ - Bucket: &s3Loc.bucketName, - Key: &s3Path, - Body: f, - }); err != nil { - t.Fatal(err) - } + s3Loc, err := s3Util.extractBucketNameAndPath(data.Data.StageInfo.Location) + if err != nil { + t.Error(err) + } + s3Path := s3Loc.s3Path + baseName(fname) + ".gz" - parentPath := filepath.Dir(filepath.Dir(s3Path)) + "/" - if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ - Bucket: &s3Loc.bucketName, - Key: &parentPath, - Body: f, - }); err == nil { - t.Fatal("should have failed attempting to put file in parent path") - } + f, err := os.Open(fname) + if err != nil { + t.Error(err) + } + defer f.Close() + uploader := manager.NewUploader(client) + if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &s3Loc.bucketName, + Key: &s3Path, + Body: f, + }); err != nil { + t.Fatal(err) + } - info := execResponseStageInfo{ - Creds: execResponseCredentials{ - AwsID: data.Data.StageInfo.Creds.AwsID, - AwsSecretKey: data.Data.StageInfo.Creds.AwsSecretKey, - }, - } - s3Cli, err = s3Util.createClient(&info, false) - if err != nil { - t.Error(err) - } - client = s3Cli.(*s3.Client) + parentPath := filepath.Dir(filepath.Dir(s3Path)) + "/" + if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &s3Loc.bucketName, + Key: &parentPath, + Body: f, + }); err == nil { + t.Fatal("should have failed attempting to put file in parent path") + } - uploader = manager.NewUploader(client) - if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ - Bucket: &s3Loc.bucketName, - Key: &s3Path, - Body: f, - }); err == nil { - t.Fatal("should have failed attempting to put with missing aws token") - } + info := execResponseStageInfo{ + Creds: execResponseCredentials{ + AwsID: data.Data.StageInfo.Creds.AwsID, + AwsSecretKey: data.Data.StageInfo.Creds.AwsSecretKey, + }, + } + s3Cli, err = s3Util.createClient(&info, false) + if err != nil { + t.Error(err) + } + client = s3Cli.(*s3.Client) + + uploader = manager.NewUploader(client) + if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &s3Loc.bucketName, + Key: &s3Path, + Body: f, + }); err == nil { + t.Fatal("should have failed attempting to put with missing aws token") + } + }) } func TestPretendToPutButList(t *testing.T) { @@ -218,50 +206,41 @@ func TestPretendToPutButList(t *testing.T) { t.Fatal("could not write to gzip file") } - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sc := sct.sc + tableName := randomString(5) + if _, err = sc.Exec("create or replace table "+tableName+ + " (a int, b string)", nil); err != nil { + t.Fatal(err) + } + defer sc.Exec("drop table "+tableName, nil) - tableName := randomString(5) - if _, err = sc.Exec("create or replace table "+tableName+ - " (a int, b string)", nil); err != nil { - t.Fatal(err) - } - defer sc.Exec("drop table "+tableName, nil) + jsonBody, err := json.Marshal(execRequest{ + SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), + }) + if err != nil { + t.Error(err) + } + headers := getHeaders() + headers[httpHeaderAccept] = headerContentTypeApplicationJSON + data, err := sc.rest.FuncPostQuery( + sc.ctx, sc.rest, &url.Values{}, headers, jsonBody, + sc.rest.RequestTimeout, 
getOrGenerateRequestIDFromContext(sc.ctx), sc.cfg) + if err != nil { + t.Fatal(err) + } - jsonBody, err := json.Marshal(execRequest{ - SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), + s3Util := new(snowflakeS3Client) + s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) + if err != nil { + t.Error(err) + } + client := s3Cli.(*s3.Client) + if _, err = client.ListBuckets(context.Background(), + &s3.ListBucketsInput{}); err == nil { + t.Fatal("list buckets should fail") + } }) - if err != nil { - t.Error(err) - } - headers := getHeaders() - headers[httpHeaderAccept] = headerContentTypeApplicationJSON - data, err := sc.rest.FuncPostQuery( - sc.ctx, sc.rest, &url.Values{}, headers, jsonBody, - sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sc.ctx), sc.cfg) - if err != nil { - t.Fatal(err) - } - - s3Util := new(snowflakeS3Client) - s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) - if err != nil { - t.Error(err) - } - client := s3Cli.(*s3.Client) - if _, err = client.ListBuckets(context.Background(), - &s3.ListBucketsInput{}); err == nil { - t.Fatal("list buckets should fail") - } } func TestPutGetAWSStage(t *testing.T) { diff --git a/rows_test.go b/rows_test.go index 0d8fbc11c..39285b8fd 100644 --- a/rows_test.go +++ b/rows_test.go @@ -452,40 +452,28 @@ func TestDownloadChunkErrorStatus(t *testing.T) { func TestWithArrowBatchesNotImplementedForResult(t *testing.T) { ctx := WithArrowBatches(context.Background()) - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - defer sc.Close() - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { - if _, err = sc.Exec("create or replace table testArrowBatches (a int, b int)", nil); err != nil { - t.Fatal(err) - } - defer sc.Exec("drop table if exists testArrowBatches", nil) + sct.mustExec("create or replace table testArrowBatches (a int, b int)", nil) + defer sct.sc.Exec("drop table if exists testArrowBatches", nil) - result, err := sc.ExecContext(ctx, "insert into testArrowBatches values (1, 2), (3, 4), (5, 6)", []driver.NamedValue{}) - if err != nil { - t.Error(err) - } + result, err := sct.sc.ExecContext(ctx, "insert into testArrowBatches values (1, 2), (3, 4), (5, 6)", []driver.NamedValue{}) + if err != nil { + t.Error(err) + } - _, err = result.(*snowflakeResult).GetArrowBatches() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrNotImplemented { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNotImplemented, driverErr.Number) - } + _, err = result.(*snowflakeResult).GetArrowBatches() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrNotImplemented { + t.Fatalf("unexpected error code. 
expected: %v, got: %v", ErrNotImplemented, driverErr.Number) + } + }) } func TestLocationChangesAfterAlterSession(t *testing.T) { diff --git a/telemetry_test.go b/telemetry_test.go index 9a09c2bdb..783ee501c 100644 --- a/telemetry_test.go +++ b/telemetry_test.go @@ -14,128 +14,90 @@ import ( ) func TestTelemetryAddLog(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { - st := &snowflakeTelemetry{ - sr: sc.rest, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: defaultFlushSize, - } - rand.Seed(time.Now().UnixNano()) - randNum := rand.Int() % 10000 - for i := 0; i < randNum; i++ { - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { + st := &snowflakeTelemetry{ + sr: sct.sc.rest, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } + rand.Seed(time.Now().UnixNano()) + randNum := rand.Int() % 10000 + for i := 0; i < randNum; i++ { + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + } + if len(st.logs) != randNum%defaultFlushSize { + t.Errorf("length of remaining logs does not match. expected: %v, got: %v", + randNum%defaultFlushSize, len(st.logs)) + } + if err := st.sendBatch(); err != nil { t.Fatal(err) } - } - if len(st.logs) != randNum%defaultFlushSize { - t.Errorf("length of remaining logs does not match. expected: %v, got: %v", - randNum%defaultFlushSize, len(st.logs)) - } - if err = st.sendBatch(); err != nil { - t.Fatal(err) - } + }) } func TestTelemetrySQLException(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { - st := &snowflakeTelemetry{ - sr: sc.rest, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: defaultFlushSize, - } - sc.telemetry = st - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: make([]string, 0), - }, - } - if err = sfa.initFileMetadata(); err == nil { - t.Fatal("this should have thrown an error") - } - if len(st.logs) != 1 { - t.Errorf("there should be 1 telemetry data in log. found: %v", len(st.logs)) - } - if sendErr := st.sendBatch(); sendErr != nil { - t.Fatal(sendErr) - } - if len(st.logs) != 0 { - t.Errorf("there should be no telemetry data in log. found: %v", len(st.logs)) - } + st := &snowflakeTelemetry{ + sr: sct.sc.rest, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } + sct.sc.telemetry = st + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: make([]string, 0), + }, + } + if err := sfa.initFileMetadata(); err == nil { + t.Fatal("this should have thrown an error") + } + if len(st.logs) != 1 { + t.Errorf("there should be 1 telemetry data in log. 
found: %v", len(st.logs)) + } + if sendErr := st.sendBatch(); sendErr != nil { + t.Fatal(sendErr) + } + if len(st.logs) != 0 { + t.Errorf("there should be no telemetry data in log. found: %v", len(st.logs)) + } + }) } func TestDisableTelemetry(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - config.DisableTelemetry = true - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - if !sc.cfg.DisableTelemetry { - t.Errorf("DisableTelemetry should be true. DisableTelemetry: %v", sc.cfg.DisableTelemetry) - } - if sc.telemetry.enabled { - t.Errorf("telemetry should be disabled.") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + if !sct.sc.cfg.DisableTelemetry { + t.Errorf("DisableTelemetry should be true. DisableTelemetry: %v", sct.sc.cfg.DisableTelemetry) + } + if sct.sc.telemetry.enabled { + t.Errorf("telemetry should be disabled.") + } + }) } func TestEnableTelemetry(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - config.DisableTelemetry = false - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - if sc.cfg.DisableTelemetry { - t.Errorf("DisableTelemetry should be false. DisableTelemetry: %v", sc.cfg.DisableTelemetry) - } - if !sc.telemetry.enabled { - t.Errorf("telemetry should be enabled.") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + if sct.sc.cfg.DisableTelemetry { + t.Errorf("DisableTelemetry should be false. DisableTelemetry: %v", sct.sc.cfg.DisableTelemetry) + } + if !sct.sc.telemetry.enabled { + t.Errorf("telemetry should be enabled.") + } + }) } func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { @@ -143,203 +105,167 @@ func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.UR } func TestTelemetryError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sr := &snowflakeRestful{ - FuncPost: funcPostTelemetryRespFail, - TokenAccessor: getSimpleTokenAccessor(), - } - st := &snowflakeTelemetry{ - sr: sr, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: defaultFlushSize, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sr := &snowflakeRestful{ + FuncPost: funcPostTelemetryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } + err := st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + }) } func TestTelemetryDisabledOnBadResponse(t *testing.T) { - config, err := 
ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sr := &snowflakeRestful{ - FuncPost: postTestAppBadGatewayError, - TokenAccessor: getSimpleTokenAccessor(), - } - st := &snowflakeTelemetry{ - sr: sr, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: defaultFlushSize, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sr := &snowflakeRestful{ + FuncPost: postTestAppBadGatewayError, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } - if st.enabled == true { - t.Fatal("telemetry should be disabled") - } + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err := st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } - st.enabled = true - st.sr.FuncPost = postTestQueryNotExecuting - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } - if st.enabled == true { - t.Fatal("telemetry should be disabled") - } + st.enabled = true + st.sr.FuncPost = postTestQueryNotExecuting + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } - st.enabled = true - st.sr.FuncPost = postTestSuccessButInvalidJSON - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } - if st.enabled == true { - t.Fatal("telemetry should be disabled") - } + st.enabled = true + st.sr.FuncPost = postTestSuccessButInvalidJSON + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } + }) } func TestTelemetryDisabled(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - 
sr := &snowflakeRestful{ - FuncPost: postTestAppBadGatewayError, - TokenAccessor: getSimpleTokenAccessor(), - } - st := &snowflakeTelemetry{ - sr: sr, - mutex: &sync.Mutex{}, - enabled: false, // disable - flushSize: defaultFlushSize, - } - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err == nil { - t.Fatal("should have failed") - } - st.enabled = true - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } - st.enabled = false - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sr := &snowflakeRestful{ + FuncPost: postTestAppBadGatewayError, + TokenAccessor: getSimpleTokenAccessor(), + } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: false, // disable + flushSize: defaultFlushSize, + } + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err == nil { + t.Fatal("should have failed") + } + st.enabled = true + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + st.enabled = false + err := st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + }) } func TestAddLogError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { - sr := &snowflakeRestful{ - FuncPost: funcPostTelemetryRespFail, - TokenAccessor: getSimpleTokenAccessor(), - } + sr := &snowflakeRestful{ + FuncPost: funcPostTelemetryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + } - st := &snowflakeTelemetry{ - sr: sr, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: 1, - } + st := &snowflakeTelemetry{ + sr: sr, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: 1, + } - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err == nil { - t.Fatal("should have failed") - } + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err == nil { + t.Fatal("should have failed") + } + }) }
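
Note: every rewritten test above depends on the runSnowflakeConnTest helper
and the SCTest wrapper that this patch extracts; their definitions live in
the driver_test.go and connection_test.go hunks outside this excerpt. For
orientation, here is a minimal sketch of the shape the call sites assume: an
authenticated connection exposed as sct.sc, plus mustExec/mustQuery helpers
that fail the test on error. The bodies below are an assumption
reconstructed from usage, not the exact code added by this patch:

	// Sketch only: reconstructed from how the tests above use the helper.
	// Assumed imports: "context", "database/sql/driver", "testing".

	// SCTest couples an authenticated *snowflakeConn with the current
	// *testing.T so that helper methods can fail the test directly.
	type SCTest struct {
		t  *testing.T
		sc *snowflakeConn
	}

	// mustExec runs a statement through the raw connection and fails the
	// test on error.
	func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result {
		result, err := sct.sc.Exec(query, args)
		if err != nil {
			sct.t.Fatalf("cannot exec. err: %v", err)
		}
		return result
	}

	// mustQuery is the query-side twin of mustExec.
	func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows {
		rows, err := sct.sc.Query(query, args)
		if err != nil {
			sct.t.Fatalf("cannot query. err: %v", err)
		}
		return rows
	}

	// runSnowflakeConnTest centralizes the ParseDSN, buildSnowflakeConn and
	// authenticateWithConfig boilerplate that each test previously repeated,
	// then hands the prepared connection to the test body.
	func runSnowflakeConnTest(t *testing.T, f func(sct *SCTest)) {
		config, err := ParseDSN(dsn)
		if err != nil {
			t.Fatalf("failed to parse dsn. err: %v", err)
		}
		sc, err := buildSnowflakeConn(context.Background(), *config)
		if err != nil {
			t.Fatalf("failed to build connection. err: %v", err)
		}
		if err = authenticateWithConfig(sc); err != nil {
			t.Fatalf("failed to authenticate. err: %v", err)
		}
		defer sc.Close()
		f(&SCTest{t: t, sc: sc})
	}

With a helper of this shape, the repeated twelve-line setup block in each
test collapses to a single runSnowflakeConnTest call, and error handling for
routine statements moves into mustExec/mustQuery, which is the pattern
visible throughout the hunks above.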