diff --git a/.github/workflows/snyk-issue.yml b/.github/workflows/snyk-issue.yml index 7d076752f..a2d594534 100644 --- a/.github/workflows/snyk-issue.yml +++ b/.github/workflows/snyk-issue.yml @@ -6,8 +6,13 @@ on: concurrency: snyk-issue +permissions: + contents: read + issues: write + pull-requests: write + jobs: - whitesource: + snyk: runs-on: ubuntu-latest steps: - name: checkout action diff --git a/.github/workflows/snyk-pr.yml b/.github/workflows/snyk-pr.yml index 861d59028..156d8e16e 100644 --- a/.github/workflows/snyk-pr.yml +++ b/.github/workflows/snyk-pr.yml @@ -1,10 +1,17 @@ name: snyk-pr + on: pull_request: branches: - master + +permissions: + contents: read + issues: write + pull-requests: write + jobs: - whitesource: + snyk: runs-on: ubuntu-latest permissions: write-all if: ${{ github.event.pull_request.user.login == 'sfc-gh-snyk-sca-sa' }} diff --git a/README.md b/README.md index 97ee31c55..5056d80df 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,12 @@ The latest driver requires the [Go language](https://golang.org/) 1.19 or higher # Installation +If you don't have a project initialized, set it up. + +```sh +go mod init example.com/snowflake +``` + Get Gosnowflake source code, if not installed. ```sh diff --git a/arrow_test.go b/arrow_test.go index 1234a6765..cca07603d 100644 --- a/arrow_test.go +++ b/arrow_test.go @@ -13,6 +13,27 @@ import ( "time" ) +//A test just to show Snowflake version +func TestCheckVersion(t *testing.T) { + conn := openConn(t) + defer conn.Close() + + rows, err := conn.QueryContext(context.Background(), "SELECT current_version()") + if err != nil { + t.Error(err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatalf("failed to find any row") + } + var s string + if err = rows.Scan(&s); err != nil { + t.Fatal(err) + } + println(s) +} + func TestArrowBigInt(t *testing.T) { conn := openConn(t) defer conn.Close() diff --git a/async.go b/async.go index 65434e4a9..d29b24b12 100644 --- a/async.go +++ b/async.go @@ -105,7 +105,7 @@ func (sr *snowflakeRestful) getAsync( } - sc := &snowflakeConn{rest: sr, cfg: cfg} + sc := &snowflakeConn{rest: sr, cfg: cfg, queryContextCache: (&queryContextCache{}).init()} if respd.Success { if resType == execResultType { res.insertID = -1 diff --git a/auth_test.go b/auth_test.go index 35b80e327..43123d760 100644 --- a/auth_test.go +++ b/auth_test.go @@ -368,26 +368,24 @@ func postAuthCheckExternalBrowserFailed(_ context.Context, _ *snowflakeRestful, } func getDefaultSnowflakeConn() *snowflakeConn { - cfg := Config{ - Account: "a", - User: "u", - Password: "p", - Database: "d", - Schema: "s", - Warehouse: "w", - Role: "r", - Region: "", - Params: make(map[string]*string), - PasscodeInPassword: false, - Passcode: "", - Application: "testapp", - } - sr := &snowflakeRestful{ - TokenAccessor: getSimpleTokenAccessor(), - } sc := &snowflakeConn{ - rest: sr, - cfg: &cfg, + rest: &snowflakeRestful{ + TokenAccessor: getSimpleTokenAccessor(), + }, + cfg: &Config{ + Account: "a", + User: "u", + Password: "p", + Database: "d", + Schema: "s", + Warehouse: "w", + Role: "r", + Region: "", + Params: make(map[string]*string), + PasscodeInPassword: false, + Passcode: "", + Application: "testapp", + }, telemetry: &snowflakeTelemetry{enabled: false}, } return sc diff --git a/chunk_test.go b/chunk_test.go index c9a158602..56b24354b 100644 --- a/chunk_test.go +++ b/chunk_test.go @@ -395,253 +395,219 @@ func TestWithStreamDownloader(t *testing.T) { } func TestWithArrowBatches(t *testing.T) { - ctx := WithArrowBatches(context.Background()) - numrows := 3000 // approximately 6 ArrowBatch objects - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := WithArrowBatches(sct.sc.ctx) + numrows := 3000 // approximately 6 ArrowBatch objects - pool := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer pool.AssertSize(t, 0) - ctx = WithArrowAllocator(ctx, pool) + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + ctx = WithArrowAllocator(ctx, pool) - query := fmt.Sprintf(selectRandomGenerator, numrows) - rows, err := sc.QueryContext(ctx, query, []driver.NamedValue{}) - if err != nil { - t.Error(err) - } - defer rows.Close() + query := fmt.Sprintf(selectRandomGenerator, numrows) + rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) + defer rows.Close() - // getting result batches - batches, err := rows.(*snowflakeRows).GetArrowBatches() - if err != nil { - t.Error(err) - } - numBatches := len(batches) - maxWorkers := 10 // enough for 3000 rows - type count struct { - m sync.Mutex - recVal int - metaVal int - } - cnt := count{recVal: 0} - var wg sync.WaitGroup - chunks := make(chan int, numBatches) - - // kicking off download workers - each of which will call fetch on a different result batch - for w := 1; w <= maxWorkers; w++ { - wg.Add(1) - go func(wg *sync.WaitGroup, chunks <-chan int) { - defer wg.Done() - - for i := range chunks { - rec, err := batches[i].Fetch() - if err != nil { - t.Error(err) - } - for _, r := range *rec { + // getting result batches + batches, err := rows.(*snowflakeRows).GetArrowBatches() + if err != nil { + t.Error(err) + } + numBatches := len(batches) + maxWorkers := 10 // enough for 3000 rows + type count struct { + m sync.Mutex + recVal int + metaVal int + } + cnt := count{recVal: 0} + var wg sync.WaitGroup + chunks := make(chan int, numBatches) + + // kicking off download workers - each of which will call fetch on a different result batch + for w := 1; w <= maxWorkers; w++ { + wg.Add(1) + go func(wg *sync.WaitGroup, chunks <-chan int) { + defer wg.Done() + + for i := range chunks { + rec, err := batches[i].Fetch() + if err != nil { + t.Error(err) + } + for _, r := range *rec { + cnt.m.Lock() + cnt.recVal += int(r.NumRows()) + cnt.m.Unlock() + r.Release() + } cnt.m.Lock() - cnt.recVal += int(r.NumRows()) + cnt.metaVal += batches[i].rowCount cnt.m.Unlock() - r.Release() } - cnt.m.Lock() - cnt.metaVal += batches[i].rowCount - cnt.m.Unlock() - } - }(&wg, chunks) - } - for j := 0; j < numBatches; j++ { - chunks <- j - } - close(chunks) + }(&wg, chunks) + } + for j := 0; j < numBatches; j++ { + chunks <- j + } + close(chunks) - // wait for workers to finish fetching and check row counts - wg.Wait() - if cnt.recVal != numrows { - t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) - } - if cnt.metaVal != numrows { - t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) - } + // wait for workers to finish fetching and check row counts + wg.Wait() + if cnt.recVal != numrows { + t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) + } + if cnt.metaVal != numrows { + t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) + } + }) } func TestWithArrowBatchesAsync(t *testing.T) { - ctx := WithAsyncMode(context.Background()) - ctx = WithArrowBatches(ctx) - numrows := 50000 // approximately 10 ArrowBatch objects - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := WithAsyncMode(sct.sc.ctx) + ctx = WithArrowBatches(ctx) + numrows := 50000 // approximately 10 ArrowBatch objects - pool := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer pool.AssertSize(t, 0) - ctx = WithArrowAllocator(ctx, pool) + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + ctx = WithArrowAllocator(ctx, pool) - query := fmt.Sprintf(selectRandomGenerator, numrows) - rows, err := sc.QueryContext(ctx, query, []driver.NamedValue{}) - if err != nil { - t.Error(err) - } - defer rows.Close() + query := fmt.Sprintf(selectRandomGenerator, numrows) + rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) + defer rows.Close() - // getting result batches - // this will fail if GetArrowBatches() is not a blocking call - batches, err := rows.(*snowflakeRows).GetArrowBatches() - if err != nil { - t.Error(err) - } - numBatches := len(batches) - maxWorkers := 10 - type count struct { - m sync.Mutex - recVal int - metaVal int - } - cnt := count{recVal: 0} - var wg sync.WaitGroup - chunks := make(chan int, numBatches) - - // kicking off download workers - each of which will call fetch on a different result batch - for w := 1; w <= maxWorkers; w++ { - wg.Add(1) - go func(wg *sync.WaitGroup, chunks <-chan int) { - defer wg.Done() - - for i := range chunks { - rec, err := batches[i].Fetch() - if err != nil { - t.Error(err) - } - for _, r := range *rec { + // getting result batches + // this will fail if GetArrowBatches() is not a blocking call + batches, err := rows.(*snowflakeRows).GetArrowBatches() + if err != nil { + t.Error(err) + } + numBatches := len(batches) + maxWorkers := 10 + type count struct { + m sync.Mutex + recVal int + metaVal int + } + cnt := count{recVal: 0} + var wg sync.WaitGroup + chunks := make(chan int, numBatches) + + // kicking off download workers - each of which will call fetch on a different result batch + for w := 1; w <= maxWorkers; w++ { + wg.Add(1) + go func(wg *sync.WaitGroup, chunks <-chan int) { + defer wg.Done() + + for i := range chunks { + rec, err := batches[i].Fetch() + if err != nil { + t.Error(err) + } + for _, r := range *rec { + cnt.m.Lock() + cnt.recVal += int(r.NumRows()) + cnt.m.Unlock() + r.Release() + } cnt.m.Lock() - cnt.recVal += int(r.NumRows()) + cnt.metaVal += batches[i].rowCount cnt.m.Unlock() - r.Release() } - cnt.m.Lock() - cnt.metaVal += batches[i].rowCount - cnt.m.Unlock() - } - }(&wg, chunks) - } - for j := 0; j < numBatches; j++ { - chunks <- j - } - close(chunks) + }(&wg, chunks) + } + for j := 0; j < numBatches; j++ { + chunks <- j + } + close(chunks) - // wait for workers to finish fetching and check row counts - wg.Wait() - if cnt.recVal != numrows { - t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) - } - if cnt.metaVal != numrows { - t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) - } + // wait for workers to finish fetching and check row counts + wg.Wait() + if cnt.recVal != numrows { + t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) + } + if cnt.metaVal != numrows { + t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) + } + }) } func TestQueryArrowStream(t *testing.T) { - ctx := context.Background() - numrows := 50000 // approximately 10 ArrowBatch objects - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - query := fmt.Sprintf(selectRandomGenerator, numrows) - loader, err := sc.QueryArrowStream(ctx, query) - if err != nil { - t.Error(err) - } - - if loader.TotalRows() != int64(numrows) { - t.Errorf("total numrows did not match expected, wanted %v, got %v", numrows, loader.TotalRows()) - } - - batches, err := loader.GetBatches() - if err != nil { - t.Error(err) - } - - numBatches := len(batches) - maxWorkers := 8 - chunks := make(chan int, numBatches) - total := int64(0) - meta := int64(0) - - var wg sync.WaitGroup - wg.Add(maxWorkers) + runSnowflakeConnTest(t, func(sct *SCTest) { + numrows := 50000 // approximately 10 ArrowBatch objects - mem := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer mem.AssertSize(t, 0) - - for w := 0; w < maxWorkers; w++ { - go func() { - defer wg.Done() + query := fmt.Sprintf(selectRandomGenerator, numrows) + loader, err := sct.sc.QueryArrowStream(sct.sc.ctx, query) + if err != nil { + t.Error(err) + } - for i := range chunks { - r, err := batches[i].GetStream(ctx) - if err != nil { - t.Error(err) - continue - } - rdr, err := ipc.NewReader(r, ipc.WithAllocator(mem)) - if err != nil { - t.Errorf("Error creating IPC reader for stream %d: %s", i, err) - r.Close() - continue - } + if loader.TotalRows() != int64(numrows) { + t.Errorf("total numrows did not match expected, wanted %v, got %v", numrows, loader.TotalRows()) + } - for rdr.Next() { - rec := rdr.Record() - atomic.AddInt64(&total, rec.NumRows()) - } + batches, err := loader.GetBatches() + if err != nil { + t.Error(err) + } - if rdr.Err() != nil { - t.Error(rdr.Err()) - } - rdr.Release() - if err := r.Close(); err != nil { - t.Error(err) + numBatches := len(batches) + maxWorkers := 8 + chunks := make(chan int, numBatches) + total := int64(0) + meta := int64(0) + + var wg sync.WaitGroup + wg.Add(maxWorkers) + + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + for w := 0; w < maxWorkers; w++ { + go func() { + defer wg.Done() + + for i := range chunks { + r, err := batches[i].GetStream(sct.sc.ctx) + if err != nil { + t.Error(err) + continue + } + rdr, err := ipc.NewReader(r, ipc.WithAllocator(mem)) + if err != nil { + t.Errorf("Error creating IPC reader for stream %d: %s", i, err) + r.Close() + continue + } + + for rdr.Next() { + rec := rdr.Record() + atomic.AddInt64(&total, rec.NumRows()) + } + + if rdr.Err() != nil { + t.Error(rdr.Err()) + } + rdr.Release() + if err := r.Close(); err != nil { + t.Error(err) + } + atomic.AddInt64(&meta, batches[i].NumRows()) } - atomic.AddInt64(&meta, batches[i].NumRows()) - } - }() - } + }() + } - for j := 0; j < numBatches; j++ { - chunks <- j - } - close(chunks) - wg.Wait() + for j := 0; j < numBatches; j++ { + chunks <- j + } + close(chunks) + wg.Wait() - if total != int64(numrows) { - t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, total) - } - if meta != int64(numrows) { - t.Errorf("number of rows from batch metadata didn't match. expected: %v, got: %v", numrows, total) - } + if total != int64(numrows) { + t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, total) + } + if meta != int64(numrows) { + t.Errorf("number of rows from batch metadata didn't match. expected: %v, got: %v", numrows, total) + } + }) } diff --git a/connection.go b/connection.go index aa26e8532..4638175e9 100644 --- a/connection.go +++ b/connection.go @@ -37,9 +37,10 @@ const ( ) const ( - statementTypeIDMulti = int64(0x1000) + statementTypeIDSelect = int64(0x1000) statementTypeIDDml = int64(0x3000) statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500) + statementTypeIDMultistatement = int64(0xA000) ) const ( @@ -60,20 +61,18 @@ const ( const privateLinkSuffix = "privatelink.snowflakecomputing.com" type snowflakeConn struct { - ctx context.Context - cfg *Config - rest *snowflakeRestful - SequenceCounter uint64 - QueryID string - SQLState string - telemetry *snowflakeTelemetry - internal InternalClient + ctx context.Context + cfg *Config + rest *snowflakeRestful + SequenceCounter uint64 + telemetry *snowflakeTelemetry + internal InternalClient + queryContextCache *queryContextCache } var ( queryIDPattern = `[\w\-_]+` queryIDRegexp = regexp.MustCompile(queryIDPattern) - errMutex = &sync.Mutex{} ) func (sc *snowflakeConn) exec( @@ -87,6 +86,10 @@ func (sc *snowflakeConn) exec( var err error counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter + queryContext, err := buildQueryContext(sc.queryContextCache) + if err != nil { + logger.Errorf("error while building query context: %v", err) + } req := execRequest{ SQLText: query, AsyncExec: noResult, @@ -94,6 +97,7 @@ func (sc *snowflakeConn) exec( IsInternal: isInternal, DescribeOnly: describeOnly, SequenceID: counter, + QueryContext: queryContext, } if key := ctx.Value(multiStatementCount); key != nil { req.Parameters[string(multiStatementCount)] = key @@ -139,12 +143,19 @@ func (sc *snowflakeConn) exec( } logger.WithContext(ctx).Infof("Success: %v, Code: %v", data.Success, code) if !data.Success { - errMutex.Lock() - defer errMutex.Unlock() err = (populateErrorFields(code, data)).exceptionTelemetry(sc) return nil, err } + if !sc.cfg.DisableQueryContextCache && data.Data.QueryContext != nil { + queryContext, err := extractQueryContext(data) + if err != nil { + logger.Errorf("error while decoding query context: ", err) + } else { + sc.queryContextCache.add(sc, queryContext.Entries...) + } + } + // handle PUT/GET commands if isFileTransfer(query) { data, err = sc.processFileTransfer(ctx, data, query, isInternal) @@ -158,12 +169,37 @@ func (sc *snowflakeConn) exec( sc.cfg.Schema = data.Data.FinalSchemaName sc.cfg.Role = data.Data.FinalRoleName sc.cfg.Warehouse = data.Data.FinalWarehouseName - sc.QueryID = data.Data.QueryID - sc.SQLState = data.Data.SQLState sc.populateSessionParameters(data.Data.Parameters) return data, err } +func extractQueryContext(data *execResponse) (queryContext, error) { + var queryContext queryContext + err := json.Unmarshal(data.Data.QueryContext, &queryContext) + return queryContext, err +} + +func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) { + rqc := requestQueryContext{} + if qcc == nil || len(qcc.entries) == 0 { + logger.Debugf("empty qcc") + return rqc, nil + } + for _, qce := range qcc.entries { + contextData := contextData{} + if qce.Context == "" { + contextData.Base64Data = qce.Context + } + rqc.Entries = append(rqc.Entries, requestQueryContextEntry{ + ID: qce.ID, + Priority: qce.Priority, + Timestamp: qce.Timestamp, + Context: contextData, + }) + } + return rqc, nil +} + func (sc *snowflakeConn) Begin() (driver.Tx, error) { return sc.BeginTx(sc.ctx, driver.TxOptions{}) } @@ -282,7 +318,7 @@ func (sc *snowflakeConn) ExecContext( return &snowflakeResult{ affectedRows: updatedRows, insertID: -1, - queryID: sc.QueryID, + queryID: data.Data.QueryID, }, nil // last insert id is not supported by Snowflake } else if isMultiStmt(&data.Data) { return sc.handleMultiExec(ctx, data.Data) @@ -353,7 +389,7 @@ func (sc *snowflakeConn) queryContextInternal( rows := new(snowflakeRows) rows.sc = sc - rows.queryID = sc.QueryID + rows.queryID = data.Data.QueryID if isMultiStmt(&data.Data) { // handleMultiQuery is responsible to fill rows with childResults @@ -683,9 +719,10 @@ func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamB func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) { sc := &snowflakeConn{ - SequenceCounter: 0, - ctx: ctx, - cfg: &config, + SequenceCounter: 0, + ctx: ctx, + cfg: &config, + queryContextCache: (&queryContextCache{}).init(), } var st http.RoundTripper = SnowflakeTransport if sc.cfg.Transporter == nil { diff --git a/connection_test.go b/connection_test.go index 3f77c66e7..0aa2e90da 100644 --- a/connection_test.go +++ b/connection_test.go @@ -90,8 +90,9 @@ func TestExecWithEmptyRequestID(t *testing.T) { } sc := &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + queryContextCache: (&queryContextCache{}).init(), } if _, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ false /* describeOnly */, nil); err != nil { @@ -161,8 +162,9 @@ func TestExecWithSpecificRequestID(t *testing.T) { } sc := &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + queryContextCache: (&queryContextCache{}).init(), } if _, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ false /* describeOnly */, nil); err != nil { @@ -181,8 +183,9 @@ func TestServiceName(t *testing.T) { } sc := &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + queryContextCache: (&queryContextCache{}).init(), } expectServiceName := serviceNameStub @@ -219,9 +222,10 @@ func TestCloseIgnoreSessionGone(t *testing.T) { FuncCloseSession: closeSessionMock, } sc := &snowflakeConn{ - cfg: &Config{Params: map[string]*string{}}, - rest: sr, - telemetry: testTelemetry, + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + telemetry: testTelemetry, + queryContextCache: (&queryContextCache{}).init(), } if sc.Close() != nil { @@ -324,8 +328,8 @@ func fetchResultByQueryID( } if _, err = sc.Exec(`create or replace table ut_conn(c1 number, c2 string) - as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 from - table(generator(rowcount => 100)))`, nil); err != nil { + as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 + from table(generator(rowcount => 100)))`, nil); err != nil { t.Fatalf("err: %v", err) } @@ -394,70 +398,44 @@ func TestPrivateLink(t *testing.T) { } func TestGetQueryStatus(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - if _, err = sc.Exec(`create or replace table ut_conn(c1 number, c2 string) - as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 from - table(generator(rowcount => 100)))`, nil); err != nil { - t.Error(err) - } - - rows, err := sc.QueryContext(ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil) - if err != nil { - t.Error(err) - } - qid := rows.(SnowflakeResult).GetQueryID() - - // use conn as type holder for SnowflakeConnection placeholder - var conn interface{} = sc - qStatus, err := conn.(SnowflakeConnection).GetQueryStatus(ctx, qid) - if err != nil { - t.Errorf("failed to get query status err = %s", err.Error()) - return - } - if qStatus == nil { - t.Error("there was no query status returned") - return - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sct.mustExec(`create or replace table ut_conn(c1 number, c2 string) + as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 + from table(generator(rowcount => 100)))`, + nil) + + rows := sct.mustQueryContext(sct.sc.ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil) + qid := rows.(SnowflakeResult).GetQueryID() + + // use conn as type holder for SnowflakeConnection placeholder + var conn interface{} = sct.sc + qStatus, err := conn.(SnowflakeConnection).GetQueryStatus(sct.sc.ctx, qid) + if err != nil { + t.Errorf("failed to get query status err = %s", err.Error()) + return + } + if qStatus == nil { + t.Error("there was no query status returned") + return + } - if qStatus.ErrorCode != "" || qStatus.ScanBytes != 2048 || qStatus.ProducedRows != 10 { - t.Errorf("expected no error. got: %v, scan bytes: %v, produced rows: %v", - qStatus.ErrorCode, qStatus.ScanBytes, qStatus.ProducedRows) - return - } + if qStatus.ErrorCode != "" || qStatus.ScanBytes != 2048 || qStatus.ProducedRows != 10 { + t.Errorf("expected no error. got: %v, scan bytes: %v, produced rows: %v", + qStatus.ErrorCode, qStatus.ScanBytes, qStatus.ProducedRows) + return + } + }) } func TestGetInvalidQueryStatus(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - sc.rest.RequestTimeout = 1 * time.Second + runSnowflakeConnTest(t, func(sct *SCTest) { + sct.sc.rest.RequestTimeout = 1 * time.Second - qStatus, err := sc.checkQueryStatus(ctx, "1234") - if err == nil || qStatus != nil { - t.Error("expected an error") - } + qStatus, err := sct.sc.checkQueryStatus(sct.sc.ctx, "1234") + if err == nil || qStatus != nil { + t.Error("expected an error") + } + }) } func TestExecWithServerSideError(t *testing.T) { @@ -487,11 +465,12 @@ func TestExecWithServerSideError(t *testing.T) { t.Error("expected a server side error") } sfe := err.(*SnowflakeError) + errUnknownError := errUnknownError() if sfe.Number != -1 || sfe.SQLState != "-1" || sfe.QueryID != "-1" { - t.Errorf("incorrect snowflake error. expected: %v, got: %v", ErrUnknownError, *sfe) + t.Errorf("incorrect snowflake error. expected: %v, got: %v", errUnknownError, *sfe) } if !strings.Contains(sfe.Message, "an unknown server side error occurred") { - t.Errorf("incorrect message. expected: %v, got: %v", ErrUnknownError.Message, sfe.Message) + t.Errorf("incorrect message. expected: %v, got: %v", errUnknownError.Message, sfe.Message) } } @@ -545,157 +524,119 @@ func postQueryFail(_ context.Context, _ *snowflakeRestful, _ *url.Values, header }, errors.New("failed to get query response") } -func TestQueryArrowStreamError(t *testing.T) { - ctx := context.Background() - numrows := 50000 // approximately 10 ArrowBatch objects - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - query := fmt.Sprintf(selectRandomGenerator, numrows) - sr := &snowflakeRestful{ - FuncPostQuery: postQueryTest, - TokenAccessor: getSimpleTokenAccessor(), - RequestTimeout: 10, - } - sc.rest = sr - _, err = sc.QueryArrowStream(ctx, query) - if err == nil { - t.Error("should have raised an error") +func TestErrorReportingOnConcurrentFails(t *testing.T) { + db := openDB(t) + defer db.Close() + var wg sync.WaitGroup + n := 5 + wg.Add(3 * n) + for i := 0; i < n; i++ { + go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_ABC", "TABLE_ABC", t, &wg) + go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_DEF", "TABLE_DEF", t, &wg) + go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_GHI", "TABLE_GHI", t, &wg) } + wg.Wait() +} - sc.rest.FuncPostQuery = postQueryFail - _, err = sc.QueryArrowStream(ctx, query) - if err == nil { - t.Error("should have raised an error") - } - _, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) +func executeQueryAndConfirmMessage(db *sql.DB, query string, expectedErrorTable string, t *testing.T, wg *sync.WaitGroup) { + defer wg.Done() + _, err := db.Exec(query) + message := err.(*SnowflakeError).Message + if !strings.Contains(message, expectedErrorTable) { + t.Errorf("QueryID: %s, Message %s ###### Expected error message table name: %s", + err.(*SnowflakeError).QueryID, err.(*SnowflakeError).Message, expectedErrorTable) } } -func TestExecContextError(t *testing.T) { - ctx := context.Background() - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } +func TestQueryArrowStreamError(t *testing.T) { + runSnowflakeConnTest(t, func(sct *SCTest) { + numrows := 50000 // approximately 10 ArrowBatch objects + query := fmt.Sprintf(selectRandomGenerator, numrows) + sct.sc.rest = &snowflakeRestful{ + FuncPostQuery: postQueryTest, + FuncCloseSession: closeSessionMock, + TokenAccessor: getSimpleTokenAccessor(), + RequestTimeout: 10, + } + _, err := sct.sc.QueryArrowStream(sct.sc.ctx, query) + if err == nil { + t.Error("should have raised an error") + } - sr := &snowflakeRestful{ - FuncPostQuery: postQueryTest, - TokenAccessor: getSimpleTokenAccessor(), - RequestTimeout: 10, - } + sct.sc.rest.FuncPostQuery = postQueryFail + _, err = sct.sc.QueryArrowStream(sct.sc.ctx, query) + if err == nil { + t.Error("should have raised an error") + } + _, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + }) +} - sc.rest = sr +func TestExecContextError(t *testing.T) { + runSnowflakeConnTest(t, func(sct *SCTest) { + sct.sc.rest = &snowflakeRestful{ + FuncPostQuery: postQueryTest, + FuncCloseSession: closeSessionMock, + TokenAccessor: getSimpleTokenAccessor(), + RequestTimeout: 10, + } - _, err = sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{}) - if err == nil { - t.Fatalf("should have raised an error") - } + _, err := sct.sc.ExecContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{}) + if err == nil { + t.Fatalf("should have raised an error") + } - sc.rest.FuncPostQuery = postQueryFail - _, err = sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{}) - if err == nil { - t.Fatalf("should have raised an error") - } - _, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } + sct.sc.rest.FuncPostQuery = postQueryFail + _, err = sct.sc.ExecContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{}) + if err == nil { + t.Fatalf("should have raised an error") + } + }) } func TestQueryContextError(t *testing.T) { - ctx := context.Background() - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - sr := &snowflakeRestful{ - FuncPostQuery: postQueryTest, - TokenAccessor: getSimpleTokenAccessor(), - RequestTimeout: 10, - } - - sc.rest = sr - - _, err = sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{}) - if err == nil { - t.Fatalf("should have raised an error") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sct.sc.rest = &snowflakeRestful{ + FuncPostQuery: postQueryTest, + FuncCloseSession: closeSessionMock, + TokenAccessor: getSimpleTokenAccessor(), + RequestTimeout: 10, + } + _, err := sct.sc.QueryContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{}) + if err == nil { + t.Fatalf("should have raised an error") + } - sc.rest.FuncPostQuery = postQueryFail - _, err = sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{}) - if err == nil { - t.Fatalf("should have raised an error") - } - _, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } + sct.sc.rest.FuncPostQuery = postQueryFail + _, err = sct.sc.QueryContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{}) + if err == nil { + t.Fatalf("should have raised an error") + } + _, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + }) } func TestPrepareQuery(t *testing.T) { - ctx := context.Background() - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - _, err = sc.Prepare("SELECT 1") + runSnowflakeConnTest(t, func(sct *SCTest) { + _, err := sct.sc.Prepare("SELECT 1") - if err != nil { - t.Fatalf("failed to prepare query. err: %v", err) - } - sc.Close() + if err != nil { + t.Fatalf("failed to prepare query. err: %v", err) + } + }) } func TestBeginCreatesTransaction(t *testing.T) { - ctx := context.Background() - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - tx, _ := sc.Begin() - if tx == nil { - t.Fatal("should have created a transaction with connection") - } - sc.Close() + runSnowflakeConnTest(t, func(sct *SCTest) { + tx, _ := sct.sc.Begin() + if tx == nil { + t.Fatal("should have created a transaction with connection") + } + }) } diff --git a/connection_util.go b/connection_util.go index 849a8a398..a1fc5a7f2 100644 --- a/connection_util.go +++ b/connection_util.go @@ -39,7 +39,7 @@ func (sc *snowflakeConn) stopHeartBeat() { if sc.cfg != nil && !sc.isClientSessionKeepAliveEnabled() { return } - if sc.rest != nil { + if sc.rest != nil && sc.rest.HeartBeat != nil { sc.rest.HeartBeat.stop() } } @@ -213,8 +213,8 @@ func updateRows(data execResponseData) (int64, error) { // Note that the statement type code is also equivalent to type INSERT, so an // additional check of the name is required func isMultiStmt(data *execResponseData) bool { - return data.StatementTypeID == statementTypeIDMulti && - data.RowType[0].Name == "multiple statement execution" + var isMultistatementByReturningSelect = data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name == "multiple statement execution" + return isMultistatementByReturningSelect || data.StatementTypeID == statementTypeIDMultistatement } func getResumeQueryID(ctx context.Context) (string, error) { diff --git a/converter_test.go b/converter_test.go index e48b3e503..6a43404f2 100644 --- a/converter_test.go +++ b/converter_test.go @@ -1242,125 +1242,90 @@ func TestArrowToRecord(t *testing.T) { } func TestTimestampLTZLocation(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - src := "1549491451.123456789" - var dest driver.Value - loc, _ := time.LoadLocation(PSTLocation) - if err = stringToValue(&dest, execResponseRowType{Type: "timestamp_ltz"}, &src, loc); err != nil { - t.Errorf("unexpected error: %v", err) - } - ts, ok := dest.(time.Time) - if !ok { - t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) - } - if ts.Location() != loc { - t.Errorf("expected location to be %v, got '%v'", loc, ts.Location()) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + src := "1549491451.123456789" + var dest driver.Value + loc, _ := time.LoadLocation(PSTLocation) + if err := stringToValue(&dest, execResponseRowType{Type: "timestamp_ltz"}, &src, loc); err != nil { + t.Errorf("unexpected error: %v", err) + } + ts, ok := dest.(time.Time) + if !ok { + t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) + } + if ts.Location() != loc { + t.Errorf("expected location to be %v, got '%v'", loc, ts.Location()) + } - if err = stringToValue(&dest, execResponseRowType{Type: "timestamp_ltz"}, &src, nil); err != nil { - t.Errorf("unexpected error: %v", err) - } - ts, ok = dest.(time.Time) - if !ok { - t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) - } - if ts.Location() != time.Local { - t.Errorf("expected location to be local, got '%v'", ts.Location()) - } + if err := stringToValue(&dest, execResponseRowType{Type: "timestamp_ltz"}, &src, nil); err != nil { + t.Errorf("unexpected error: %v", err) + } + ts, ok = dest.(time.Time) + if !ok { + t.Errorf("expected type: 'time.Time', got '%v'", reflect.TypeOf(dest)) + } + if ts.Location() != time.Local { + t.Errorf("expected location to be local, got '%v'", ts.Location()) + } + }) } func TestSmallTimestampBinding(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - timeValue, err := time.Parse("2006-01-02 15:04:05", "1600-10-10 10:10:10") - if err != nil { - t.Fatalf("failed to parse time: %v", err) - } - parameters := []driver.NamedValue{ - {Ordinal: 1, Value: DataTypeTimestampNtz}, - {Ordinal: 2, Value: timeValue}, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() + timeValue, err := time.Parse("2006-01-02 15:04:05", "1600-10-10 10:10:10") + if err != nil { + t.Fatalf("failed to parse time: %v", err) + } + parameters := []driver.NamedValue{ + {Ordinal: 1, Value: DataTypeTimestampNtz}, + {Ordinal: 2, Value: timeValue}, + } - rows, err := sc.QueryContext(ctx, "SELECT ?", parameters) - if err != nil { - t.Fatalf("failed to run query: %v", err) - } - defer rows.Close() + rows := sct.mustQueryContext(ctx, "SELECT ?", parameters) + defer rows.Close() - scanValues := make([]driver.Value, 1) - for { - if err := rows.Next(scanValues); err == io.EOF { - break - } else if err != nil { - t.Fatalf("failed to run query: %v", err) - } - if scanValues[0] != timeValue { - t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) + scanValues := make([]driver.Value, 1) + for { + if err := rows.Next(scanValues); err == io.EOF { + break + } else if err != nil { + t.Fatalf("failed to run query: %v", err) + } + if scanValues[0] != timeValue { + t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) + } } - } + }) } func TestLargeTimestampBinding(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - ctx := context.Background() - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - timeValue, err := time.Parse("2006-01-02 15:04:05", "9000-10-10 10:10:10") - if err != nil { - t.Fatalf("failed to parse time: %v", err) - } - parameters := []driver.NamedValue{ - {Ordinal: 1, Value: DataTypeTimestampNtz}, - {Ordinal: 2, Value: timeValue}, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + ctx := context.Background() + timeValue, err := time.Parse("2006-01-02 15:04:05", "9000-10-10 10:10:10") + if err != nil { + t.Fatalf("failed to parse time: %v", err) + } + parameters := []driver.NamedValue{ + {Ordinal: 1, Value: DataTypeTimestampNtz}, + {Ordinal: 2, Value: timeValue}, + } - rows, err := sc.QueryContext(ctx, "SELECT ?", parameters) - if err != nil { - t.Fatalf("failed to run query: %v", err) - } - defer rows.Close() + rows := sct.mustQueryContext(ctx, "SELECT ?", parameters) + defer rows.Close() - scanValues := make([]driver.Value, 1) - for { - if err := rows.Next(scanValues); err == io.EOF { - break - } else if err != nil { - t.Fatalf("failed to run query: %v", err) - } - if scanValues[0] != timeValue { - t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) + scanValues := make([]driver.Value, 1) + for { + if err := rows.Next(scanValues); err == io.EOF { + break + } else if err != nil { + t.Fatalf("failed to run query: %v", err) + } + if scanValues[0] != timeValue { + t.Fatalf("unexpected result. expected: %v, got: %v", timeValue, scanValues[0]) + } } - } + }) } func TestTimeTypeValueToString(t *testing.T) { diff --git a/datatype.go b/datatype.go index 2c44bebd4..73fb91499 100644 --- a/datatype.go +++ b/datatype.go @@ -32,27 +32,48 @@ const ( unSupportedType ) -var snowflakeTypes = [...]string{"FIXED", "REAL", "TEXT", "DATE", "VARIANT", - "TIMESTAMP_LTZ", "TIMESTAMP_NTZ", "TIMESTAMP_TZ", "OBJECT", "ARRAY", - "BINARY", "TIME", "BOOLEAN", "NULL", "SLICE", "CHANGE_TYPE", "NOT_SUPPORTED"} +var snowflakeToDriverType = map[string]snowflakeType{ + "FIXED": fixedType, + "REAL": realType, + "TEXT": textType, + "DATE": dateType, + "VARIANT": variantType, + "TIMESTAMP_LTZ": timestampLtzType, + "TIMESTAMP_NTZ": timestampNtzType, + "TIMESTAMP_TZ": timestampTzType, + "OBJECT": objectType, + "ARRAY": arrayType, + "BINARY": binaryType, + "TIME": timeType, + "BOOLEAN": booleanType, + "NULL": nullType, + "SLICE": sliceType, + "CHANGE_TYPE": changeType, + "NOT_SUPPORTED": unSupportedType} -func (st snowflakeType) String() string { - return snowflakeTypes[st] +var driverTypeToSnowflake = invertMap(snowflakeToDriverType) + +func invertMap(m map[string]snowflakeType) map[snowflakeType]string { + inv := make(map[snowflakeType]string) + for k, v := range m { + if _, ok := inv[v]; ok { + panic("failed to create driverTypeToSnowflake map due to duplicated values") + } + inv[v] = k + } + return inv } func (st snowflakeType) Byte() byte { return byte(st) } +func (st snowflakeType) String() string { + return driverTypeToSnowflake[st] +} + func getSnowflakeType(typ string) snowflakeType { - for i, sft := range snowflakeTypes { - if sft == typ { - return snowflakeType(i) - } else if snowflakeType(i) == nullType { - break - } - } - return nullType + return snowflakeToDriverType[typ] } var ( diff --git a/doc.go b/doc.go index fd496f67c..93def4075 100644 --- a/doc.go +++ b/doc.go @@ -118,6 +118,9 @@ The following connection parameters are supported: - tracing: Specifies the logging level to be used. Set to error by default. Valid values are trace, debug, info, print, warning, error, fatal, panic. + - disableQueryContextCache: disables parsing of query context returned from server and resending it to server as well. + Default value is false. + All other parameters are interpreted as session parameters (https://docs.snowflake.com/en/sql-reference/parameters.html). For example, the TIMESTAMP_OUTPUT_FORMAT session parameter can be set by adding: @@ -168,6 +171,36 @@ in place of the default randomized request ID. For example: ctxWithID := WithRequestID(ctx, requestID) rows, err := db.QueryContext(ctxWithID, query) +# Last query ID + +If you need query ID for your query you have to use raw connection. + +For queries: +``` + + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") + rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + rows.(SnowflakeRows).GetQueryID() + stmt.(SnowflakeStmt).GetQueryID() + return nil + } + +``` + +For execs: +``` + + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)") + result, err := stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + result.(SnowflakeResult).GetQueryID() + stmt.(SnowflakeStmt).GetQueryID() + return nil + } + +``` + # Canceling Query by CtrlC From 0.5.0, a signal handling responsibility has moved to the applications. If you want to cancel a diff --git a/driver_test.go b/driver_test.go index 634a53145..1d6662ebe 100644 --- a/driver_test.go +++ b/driver_test.go @@ -6,6 +6,7 @@ import ( "context" "crypto/rsa" "database/sql" + "database/sql/driver" "flag" "fmt" "net/http" @@ -320,6 +321,49 @@ func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) { return stmt } +type SCTest struct { + *testing.T + sc *snowflakeConn +} + +func (sct *SCTest) fail(method, query string, err error) { + if len(query) > 300 { + query = "[query too large to print]" + } + sct.Fatalf("error on %s [%s]: %s", method, query, err.Error()) +} + +func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result { + result, err := sct.sc.Exec(query, args) + if err != nil { + sct.fail("exec", query, err) + } + return result +} +func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows { + rows, err := sct.sc.Query(query, args) + if err != nil { + sct.fail("query", query, err) + } + return rows +} + +func (sct *SCTest) mustQueryContext(ctx context.Context, query string, args []driver.NamedValue) driver.Rows { + rows, err := sct.sc.QueryContext(ctx, query, args) + if err != nil { + sct.fail("QueryContext", query, err) + } + return rows +} + +func (sct *SCTest) mustExecContext(ctx context.Context, query string, args []driver.NamedValue) driver.Result { + result, err := sct.sc.ExecContext(ctx, query, args) + if err != nil { + sct.fail("ExecContext", query, err) + } + return result +} + func runDBTest(t *testing.T, test func(dbt *DBTest)) { conn := openConn(t) defer conn.Close() @@ -328,6 +372,25 @@ func runDBTest(t *testing.T, test func(dbt *DBTest)) { test(dbt) } +func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) { + config, err := ParseDSN(dsn) + if err != nil { + t.Error(err) + } + sc, err := buildSnowflakeConn(context.Background(), *config) + if err != nil { + t.Fatal(err) + } + defer sc.Close() + if err = authenticateWithConfig(sc); err != nil { + t.Fatal(err) + } + + sct := &SCTest{t, sc} + + test(sct) +} + func runningOnAWS() bool { return os.Getenv("CLOUD_PROVIDER") == "AWS" } diff --git a/dsn.go b/dsn.go index cdfd805e5..d5fa2ab73 100644 --- a/dsn.go +++ b/dsn.go @@ -97,6 +97,8 @@ type Config struct { IDToken string // Internally used to cache the Id Token for external browser ClientRequestMfaToken ConfigBool // When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux. ClientStoreTemporaryCredential ConfigBool // When true the ID token is cached in the credential manager. True by default in Windows/OSX. False for Linux. + + DisableQueryContextCache bool // Should HTAP query context cache be disabled } // Validate enables testing if config is correct. @@ -139,7 +141,7 @@ func DSN(cfg *Config) (dsn string, err error) { posDot := strings.Index(cfg.Account, ".") if posDot > 0 { if cfg.Region != "" { - return "", ErrInvalidRegion + return "", errInvalidRegion() } cfg.Region = cfg.Account[posDot+1:] cfg.Account = cfg.Account[:posDot] @@ -230,6 +232,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.TmpDirPath != "" { params.Add("tmpDirPath", cfg.TmpDirPath) } + if cfg.DisableQueryContextCache { + params.Add("disableQueryContextCache", "true") + } params.Add("ocspFailOpen", strconv.FormatBool(cfg.OCSPFailOpen != OCSPFailOpenFalse)) @@ -403,15 +408,15 @@ func fillMissingConfigParameters(cfg *Config) error { } } if strings.Trim(cfg.Account, " ") == "" { - return ErrEmptyAccount + return errEmptyAccount() } if authRequiresUser(cfg) && strings.TrimSpace(cfg.User) == "" { - return ErrEmptyUsername + return errEmptyUsername() } if authRequiresPassword(cfg) && strings.TrimSpace(cfg.Password) == "" { - return ErrEmptyPassword + return errEmptyPassword() } if strings.Trim(cfg.Protocol, " ") == "" { cfg.Protocol = "https" @@ -702,6 +707,13 @@ func parseDSNParams(cfg *Config, params string) (err error) { cfg.Tracing = value case "tmpDirPath": cfg.TmpDirPath = value + case "disableQueryContextCache": + var b bool + b, err = strconv.ParseBool(value) + if err != nil { + return + } + cfg.DisableQueryContextCache = b default: if cfg.Params == nil { cfg.Params = make(map[string]*string) diff --git a/dsn_test.go b/dsn_test.go index 21ba433f7..0086e2cf5 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -25,7 +25,6 @@ type tcParseDSN struct { } func TestParseDSN(t *testing.T) { - privKeyPKCS8 := generatePKCS8StringSupress(testPrivKey) privKeyPKCS1 := generatePKCS1String(testPrivKey) testcases := []tcParseDSN{ @@ -201,7 +200,7 @@ func TestParseDSN(t *testing.T) { ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyPassword, + err: errEmptyPassword(), }, { dsn: "@host:123/db/schema?account=ac&protocol=http", @@ -216,7 +215,7 @@ func TestParseDSN(t *testing.T) { ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyUsername, + err: errEmptyUsername(), }, { dsn: "user:p@host:123/db/schema?protocol=http", @@ -231,7 +230,7 @@ func TestParseDSN(t *testing.T) { ExternalBrowserTimeout: defaultExternalBrowserTimeout, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "u:p@a.snowflakecomputing.com/db/pa?account=a&protocol=https&role=r&timezone=UTC&warehouse=w", @@ -415,12 +414,12 @@ func TestParseDSN(t *testing.T) { { dsn: "u:u@/+/+?account=+&=0", config: &Config{}, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "u:u@/+/+?account=+&=+&=+", config: &Config{}, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { dsn: "user%40%2F1:p%3A%40s@/db%2F?account=ac", @@ -573,9 +572,10 @@ func TestParseDSN(t *testing.T) { Account: "a", User: "u", Password: "p", Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, - ClientTimeout: 300 * time.Second, - JWTClientTimeout: 45 * time.Second, - ExternalBrowserTimeout: defaultExternalBrowserTimeout, + ClientTimeout: 300 * time.Second, + JWTClientTimeout: 45 * time.Second, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + DisableQueryContextCache: false, }, ocspMode: ocspModeFailOpen, err: nil, @@ -594,6 +594,20 @@ func TestParseDSN(t *testing.T) { ocspMode: ocspModeFailOpen, err: nil, }, + { + dsn: "u:p@a.r.c.snowflakecomputing.com/db/s?account=a.r.c&disableQueryContextCache=true", + config: &Config{ + Account: "a", User: "u", Password: "p", + Protocol: "https", Host: "a.r.c.snowflakecomputing.com", Port: 443, + Database: "db", Schema: "s", ValidateDefaultParameters: ConfigBoolTrue, OCSPFailOpen: OCSPFailOpenTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + DisableQueryContextCache: true, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, } for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { @@ -630,7 +644,7 @@ func TestParseDSN(t *testing.T) { Authenticator: at, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyUsername, + err: errEmptyUsername(), }) } @@ -649,7 +663,7 @@ func TestParseDSN(t *testing.T) { Authenticator: at, }, ocspMode: ocspModeFailOpen, - err: ErrEmptyPassword, + err: errEmptyPassword(), }) } @@ -744,6 +758,9 @@ func TestParseDSN(t *testing.T) { if test.config.TmpDirPath != cfg.TmpDirPath { t.Fatalf("%v: Failed to match TmpDirPatch. expected: %v, got: %v", i, test.config.TmpDirPath, cfg.TmpDirPath) } + if test.config.DisableQueryContextCache != cfg.DisableQueryContextCache { + t.Fatalf("%v: Failed to match DisableQueryContextCache. expected: %v, got: %v", i, test.config.DisableQueryContextCache, cfg.DisableQueryContextCache) + } case test.err != nil: driverErrE, okE := test.err.(*SnowflakeError) driverErrG, okG := err.(*SnowflakeError) @@ -775,7 +792,6 @@ type tcDSN struct { func TestDSN(t *testing.T) { tmfmt := "MM-DD-YYYY" - testcases := []tcDSN{ { cfg: &Config{ @@ -809,7 +825,7 @@ func TestDSN(t *testing.T) { Account: "a-aofnadsf.global", Region: "r", }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ @@ -853,7 +869,7 @@ func TestDSN(t *testing.T) { Password: "p", Account: "a", }, - err: ErrEmptyUsername, + err: errEmptyUsername(), }, { cfg: &Config{ @@ -861,7 +877,7 @@ func TestDSN(t *testing.T) { Password: "", Account: "a", }, - err: ErrEmptyPassword, + err: errEmptyPassword(), }, { cfg: &Config{ @@ -869,7 +885,7 @@ func TestDSN(t *testing.T) { Password: "p", Account: "", }, - err: ErrEmptyAccount, + err: errEmptyAccount(), }, { cfg: &Config{ @@ -895,7 +911,7 @@ func TestDSN(t *testing.T) { Account: "a.e", Region: "r", }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ @@ -1039,7 +1055,7 @@ func TestDSN(t *testing.T) { Account: "a.b.c", Region: "r", }, - err: ErrInvalidRegion, + err: errInvalidRegion(), }, { cfg: &Config{ @@ -1135,6 +1151,15 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?ocspFailOpen=true®ion=b.c&tmpDirPath=%2Ftmp&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + DisableQueryContextCache: true, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?disableQueryContextCache=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, } for _, test := range testcases { t.Run(test.dsn, func(t *testing.T) { diff --git a/errors.go b/errors.go index 459850751..a4104f66c 100644 --- a/errors.go +++ b/errors.go @@ -82,7 +82,7 @@ func (se *SnowflakeError) exceptionTelemetry(sc *snowflakeConn) *SnowflakeError // return populated error fields replacing the default response func populateErrorFields(code int, data *execResponse) *SnowflakeError { - err := ErrUnknownError + err := errUnknownError() if code != -1 { err.Number = code } @@ -290,32 +290,44 @@ const ( errMsgInvalidPadding = "invalid padding on input" ) -var ( - // ErrEmptyAccount is returned if a DNS doesn't include account parameter. - ErrEmptyAccount = &SnowflakeError{ +// Returned if a DNS doesn't include account parameter. +func errEmptyAccount() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyAccountCode, Message: "account is empty", } - // ErrEmptyUsername is returned if a DNS doesn't include user parameter. - ErrEmptyUsername = &SnowflakeError{ +} + +// Returned if a DNS doesn't include user parameter. +func errEmptyUsername() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyUsernameCode, Message: "user is empty", } - // ErrEmptyPassword is returned if a DNS doesn't include password parameter. - ErrEmptyPassword = &SnowflakeError{ +} + +// Returned if a DNS doesn't include password parameter. +func errEmptyPassword() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeEmptyPasswordCode, - Message: "password is empty"} + Message: "password is empty", + } +} - // ErrInvalidRegion is returned if a DSN's implicit region from account parameter and explicit region parameter conflict. - ErrInvalidRegion = &SnowflakeError{ +// Returned if a DSN's implicit region from account parameter and explicit region parameter conflict. +func errInvalidRegion() *SnowflakeError { + return &SnowflakeError{ Number: ErrCodeRegionOverlap, - Message: "two regions specified"} + Message: "two regions specified", + } +} - // ErrUnknownError is returned if the server side returns an error without meaningful message. - ErrUnknownError = &SnowflakeError{ +// Returned if the server side returns an error without meaningful message. +func errUnknownError() *SnowflakeError { + return &SnowflakeError{ Number: -1, SQLState: "-1", Message: "an unknown server side error occurred", QueryID: "-1", } -) +} diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index 633bb5019..163fe3ce1 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -28,33 +28,24 @@ func TestGetBucketAccelerateConfiguration(t *testing.T) { if runningOnGithubAction() { t.Skip("Should be run against an account in AWS EU North1 region.") } - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: make([]string, 0), - }, - } - if err = sfa.transferAccelerateConfig(); err != nil { - var ae smithy.APIError - if errors.As(err, &ae) { - if ae.ErrorCode() == "MethodNotAllowed" { - t.Fatalf("should have ignored 405 error: %v", err) + runSnowflakeConnTest(t, func(sct *SCTest) { + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: make([]string, 0), + }, + } + if err := sfa.transferAccelerateConfig(); err != nil { + var ae smithy.APIError + if errors.As(err, &ae) { + if ae.ErrorCode() == "MethodNotAllowed" { + t.Fatalf("should have ignored 405 error: %v", err) + } } } - } + }) } func TestUnitDownloadWithInvalidLocalPath(t *testing.T) { @@ -88,455 +79,369 @@ func TestUnitDownloadWithInvalidLocalPath(t *testing.T) { }) } func TestUnitGetLocalFilePathFromCommand(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: make([]string, 0), - }, - } - testcases := []tcFilePath{ - {"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"}, - {"PUT 'file:///tmp/my_data_file.txt' @~ overwrite=true", "/tmp/my_data_file.txt"}, - {"PUT file:///tmp/sub_dir/my_data_file.txt\n @~ overwrite=true", "/tmp/sub_dir/my_data_file.txt"}, - {"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"}, - {"", ""}, - {"PUT 'file2:///tmp/my_data_file.txt' @~ overwrite=true", ""}, - } - for _, test := range testcases { - t.Run(test.command, func(t *testing.T) { - path := sfa.getLocalFilePathFromCommand(test.command) - if path != test.path { - t.Fatalf("unexpected file path. expected: %v, but got: %v", test.path, path) - } - }) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: make([]string, 0), + }, + } + testcases := []tcFilePath{ + {"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"}, + {"PUT 'file:///tmp/my_data_file.txt' @~ overwrite=true", "/tmp/my_data_file.txt"}, + {"PUT file:///tmp/sub_dir/my_data_file.txt\n @~ overwrite=true", "/tmp/sub_dir/my_data_file.txt"}, + {"PUT file:///tmp/my_data_file.txt @~ overwrite=true", "/tmp/my_data_file.txt"}, + {"", ""}, + {"PUT 'file2:///tmp/my_data_file.txt' @~ overwrite=true", ""}, + } + for _, test := range testcases { + t.Run(test.command, func(t *testing.T) { + path := sfa.getLocalFilePathFromCommand(test.command) + if path != test.path { + t.Fatalf("unexpected file path. expected: %v, but got: %v", test.path, path) + } + }) + } + }) } func TestUnitProcessFileCompressionType(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - } - testcases := []struct { - srcCompression string - }{ - {"none"}, - {"auto_detect"}, - {"gzip"}, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + } + testcases := []struct { + srcCompression string + }{ + {"none"}, + {"auto_detect"}, + {"gzip"}, + } - for _, test := range testcases { - t.Run(test.srcCompression, func(t *testing.T) { - sfa.srcCompression = test.srcCompression - err = sfa.processFileCompressionType() - if err != nil { - t.Fatalf("failed to process file compression") - } - }) - } + for _, test := range testcases { + t.Run(test.srcCompression, func(t *testing.T) { + sfa.srcCompression = test.srcCompression + err := sfa.processFileCompressionType() + if err != nil { + t.Fatalf("failed to process file compression") + } + }) + } - // test invalid compression type error - sfa.srcCompression = "gz" - data := &execResponseData{ - SQLState: "S00087", - QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", - } - sfa.data = data - err = sfa.processFileCompressionType() - if err == nil { - t.Fatal("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCompressionNotSupported { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCompressionNotSupported, driverErr.Number) - } + // test invalid compression type error + sfa.srcCompression = "gz" + data := &execResponseData{ + SQLState: "S00087", + QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", + } + sfa.data = data + err := sfa.processFileCompressionType() + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrCompressionNotSupported { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCompressionNotSupported, driverErr.Number) + } + }) } func TestParseCommandWithInvalidStageLocation(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: make([]string, 0), - }, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: make([]string, 0), + }, + } - err = sfa.parseCommand() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrInvalidStageLocation { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageLocation, driverErr.Number) - } + err := sfa.parseCommand() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrInvalidStageLocation { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageLocation, driverErr.Number) + } + }) } func TestParseCommandEncryptionMaterialMismatchError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - mockEncMaterial1 := snowflakeFileEncryption{ - QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", - QueryID: "01abc874-0406-1bf0-0000-53b10668e056", - SMKID: 92019681909886, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + mockEncMaterial1 := snowflakeFileEncryption{ + QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", + QueryID: "01abc874-0406-1bf0-0000-53b10668e056", + SMKID: 92019681909886, + } - mockEncMaterial2 := snowflakeFileEncryption{ - QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", - QueryID: "01abc874-0406-1bf0-0000-53b10668e056", - SMKID: 92019681909886, - } + mockEncMaterial2 := snowflakeFileEncryption{ + QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", + QueryID: "01abc874-0406-1bf0-0000-53b10668e056", + SMKID: 92019681909886, + } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: []string{"/tmp/uploads"}, - EncryptionMaterial: encryptionWrapper{ - snowflakeFileEncryption: mockEncMaterial1, - EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1, mockEncMaterial2}, + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: []string{"/tmp/uploads"}, + EncryptionMaterial: encryptionWrapper{ + snowflakeFileEncryption: mockEncMaterial1, + EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1, mockEncMaterial2}, + }, }, - }, - } + } - err = sfa.parseCommand() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrInternalNotMatchEncryptMaterial { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInternalNotMatchEncryptMaterial, driverErr.Number) - } + err := sfa.parseCommand() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrInternalNotMatchEncryptMaterial { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInternalNotMatchEncryptMaterial, driverErr.Number) + } + }) } func TestParseCommandInvalidStorageClientException(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - tmpDir, err := os.MkdirTemp("", "abc") - if err != nil { - t.Error(err) - } - mockEncMaterial1 := snowflakeFileEncryption{ - QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", - QueryID: "01abc874-0406-1bf0-0000-53b10668e056", - SMKID: 92019681909886, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + tmpDir, err := os.MkdirTemp("", "abc") + if err != nil { + t.Error(err) + } + mockEncMaterial1 := snowflakeFileEncryption{ + QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", + QueryID: "01abc874-0406-1bf0-0000-53b10668e056", + SMKID: 92019681909886, + } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: []string{"/tmp/uploads"}, - LocalLocation: tmpDir, - EncryptionMaterial: encryptionWrapper{ - snowflakeFileEncryption: mockEncMaterial1, - EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1}, + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: []string{"/tmp/uploads"}, + LocalLocation: tmpDir, + EncryptionMaterial: encryptionWrapper{ + snowflakeFileEncryption: mockEncMaterial1, + EncryptionMaterials: []snowflakeFileEncryption{mockEncMaterial1}, + }, }, - }, - options: &SnowflakeFileTransferOptions{ - DisablePutOverwrite: false, - }, - } + options: &SnowflakeFileTransferOptions{ + DisablePutOverwrite: false, + }, + } - err = sfa.parseCommand() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrInvalidStageFs { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageFs, driverErr.Number) - } + err = sfa.parseCommand() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrInvalidStageFs { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrInvalidStageFs, driverErr.Number) + } + }) } func TestInitFileMetadataError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: []string{"fileDoesNotExist.txt"}, - data: &execResponseData{ - SQLState: "123456", - QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", - }, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: []string{"fileDoesNotExist.txt"}, + data: &execResponseData{ + SQLState: "123456", + QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", + }, + } - err = sfa.initFileMetadata() - if err == nil { - t.Fatal("should have raised an error") - } + err := sfa.initFileMetadata() + if err == nil { + t.Fatal("should have raised an error") + } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrFileNotExists { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFileNotExists, driverErr.Number) - } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrFileNotExists { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFileNotExists, driverErr.Number) + } - tmpDir, err := os.MkdirTemp("", "data") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) - sfa.srcFiles = []string{tmpDir} + tmpDir, err := os.MkdirTemp("", "data") + if err != nil { + t.Error(err) + } + defer os.RemoveAll(tmpDir) + sfa.srcFiles = []string{tmpDir} - err = sfa.initFileMetadata() - if err == nil { - t.Fatal("should have raised an error") - } + err = sfa.initFileMetadata() + if err == nil { + t.Fatal("should have raised an error") + } - driverErr, ok = err.(*SnowflakeError) - if !ok || driverErr.Number != ErrFileNotExists { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFileNotExists, driverErr.Number) - } + driverErr, ok = err.(*SnowflakeError) + if !ok || driverErr.Number != ErrFileNotExists { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFileNotExists, driverErr.Number) + } + }) } func TestUpdateMetadataWithPresignedUrl(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - info := execResponseStageInfo{ - Location: "gcs-blob/storage/users/456/", - LocationType: "GCS", - } + runSnowflakeConnTest(t, func(sct *SCTest) { + info := execResponseStageInfo{ + Location: "gcs-blob/storage/users/456/", + LocationType: "GCS", + } - dir, err := os.Getwd() - if err != nil { - t.Error(err) - } + dir, err := os.Getwd() + if err != nil { + t.Error(err) + } - testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" - - presignedURLMock := func(_ context.Context, _ *snowflakeRestful, - _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, - requestID UUID, _ *Config) (*execResponse, error) { - // ensure the same requestID from context is used - if len(requestID) == 0 { - t.Fatal("requestID is empty") - } - dd := &execResponseData{ - QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", - Command: string(uploadCommand), - StageInfo: execResponseStageInfo{ - LocationType: "GCS", - Location: "gcspuscentral1-4506459564-stage/users/456", - Path: "users/456", - Region: "US_CENTRAL1", - PresignedURL: testURL, - }, + testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" + + presignedURLMock := func(_ context.Context, _ *snowflakeRestful, + _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, + requestID UUID, _ *Config) (*execResponse, error) { + // ensure the same requestID from context is used + if len(requestID) == 0 { + t.Fatal("requestID is empty") + } + dd := &execResponseData{ + QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", + Command: string(uploadCommand), + StageInfo: execResponseStageInfo{ + LocationType: "GCS", + Location: "gcspuscentral1-4506459564-stage/users/456", + Path: "users/456", + Region: "US_CENTRAL1", + PresignedURL: testURL, + }, + } + return &execResponse{ + Data: *dd, + Message: "", + Code: "0", + Success: true, + }, nil } - return &execResponse{ - Data: *dd, - Message: "", - Code: "0", - Success: true, - }, nil - } - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) - if err != nil { - t.Error(err) - } - uploadMeta := fileMetadata{ - name: "data1.txt.gz", - stageLocationType: "GCS", - noSleepingTime: true, - client: gcsCli, - sha256Digest: "123456789abcdef", - stageInfo: &info, - dstFileName: "data1.txt.gz", - srcFileName: path.Join(dir, "/test_data/data1.txt"), - overwrite: true, - options: &SnowflakeFileTransferOptions{ - MultiPartThreshold: dataSizeThreshold, - }, - } + gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + if err != nil { + t.Error(err) + } + uploadMeta := fileMetadata{ + name: "data1.txt.gz", + stageLocationType: "GCS", + noSleepingTime: true, + client: gcsCli, + sha256Digest: "123456789abcdef", + stageInfo: &info, + dstFileName: "data1.txt.gz", + srcFileName: path.Join(dir, "/test_data/data1.txt"), + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + } - sc.rest.FuncPostQuery = presignedURLMock - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - command: "put file:///tmp/test_data/data1.txt @~", - stageLocationType: gcsClient, - fileMetadata: []*fileMetadata{&uploadMeta}, - } + sct.sc.rest.FuncPostQuery = presignedURLMock + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + command: "put file:///tmp/test_data/data1.txt @~", + stageLocationType: gcsClient, + fileMetadata: []*fileMetadata{&uploadMeta}, + } - err = sfa.updateFileMetadataWithPresignedURL() - if err != nil { - t.Error(err) - } - if testURL != sfa.fileMetadata[0].presignedURL.String() { - t.Fatalf("failed to update metadata with presigned url. expected: %v. got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) - } + err = sfa.updateFileMetadataWithPresignedURL() + if err != nil { + t.Error(err) + } + if testURL != sfa.fileMetadata[0].presignedURL.String() { + t.Fatalf("failed to update metadata with presigned url. expected: %v. got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) + } + }) } func TestUpdateMetadataWithPresignedUrlForDownload(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - info := execResponseStageInfo{ - Location: "gcs-blob/storage/users/456/", - LocationType: "GCS", - } + runSnowflakeConnTest(t, func(sct *SCTest) { + info := execResponseStageInfo{ + Location: "gcs-blob/storage/users/456/", + LocationType: "GCS", + } - dir, err := os.Getwd() - if err != nil { - t.Error(err) - } + dir, err := os.Getwd() + if err != nil { + t.Error(err) + } - testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" + testURL := "https://storage.google.com/gcs-blob/storage/users/456?Signature=testsignature123" - gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) - if err != nil { - t.Error(err) - } - downloadMeta := fileMetadata{ - name: "data1.txt.gz", - stageLocationType: "GCS", - noSleepingTime: true, - client: gcsCli, - stageInfo: &info, - dstFileName: "data1.txt.gz", - overwrite: true, - srcFileName: "data1.txt.gz", - localLocation: dir, - } + gcsCli, err := new(snowflakeGcsClient).createClient(&info, false) + if err != nil { + t.Error(err) + } + downloadMeta := fileMetadata{ + name: "data1.txt.gz", + stageLocationType: "GCS", + noSleepingTime: true, + client: gcsCli, + stageInfo: &info, + dstFileName: "data1.txt.gz", + overwrite: true, + srcFileName: "data1.txt.gz", + localLocation: dir, + } - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: downloadCommand, - command: "get @~/data1.txt.gz file:///tmp/testData", - stageLocationType: gcsClient, - fileMetadata: []*fileMetadata{&downloadMeta}, - presignedURLs: []string{testURL}, - } + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: downloadCommand, + command: "get @~/data1.txt.gz file:///tmp/testData", + stageLocationType: gcsClient, + fileMetadata: []*fileMetadata{&downloadMeta}, + presignedURLs: []string{testURL}, + } - err = sfa.updateFileMetadataWithPresignedURL() - if err != nil { - t.Error(err) - } - if testURL != sfa.fileMetadata[0].presignedURL.String() { - t.Fatalf("failed to update metadata with presigned url. expected: %v. got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) - } + err = sfa.updateFileMetadataWithPresignedURL() + if err != nil { + t.Error(err) + } + if testURL != sfa.fileMetadata[0].presignedURL.String() { + t.Fatalf("failed to update metadata with presigned url. expected: %v. got: %v", testURL, sfa.fileMetadata[0].presignedURL.String()) + } + }) } func TestUpdateMetadataWithPresignedUrlError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - sfa := &snowflakeFileTransferAgent{ - sc: sc, - command: "get @~/data1.txt.gz file:///tmp/testData", - stageLocationType: gcsClient, - data: &execResponseData{ - SQLState: "123456", - QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", - }, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + command: "get @~/data1.txt.gz file:///tmp/testData", + stageLocationType: gcsClient, + data: &execResponseData{ + SQLState: "123456", + QueryID: "01aa2e8b-0405-ab7c-0000-53b10632f626", + }, + } - err = sfa.updateFileMetadataWithPresignedURL() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok || driverErr.Number != ErrCommandNotRecognized { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCommandNotRecognized, driverErr.Number) - } + err := sfa.updateFileMetadataWithPresignedURL() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok || driverErr.Number != ErrCommandNotRecognized { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCommandNotRecognized, driverErr.Number) + } + }) } func TestUploadWhenFilesystemReadOnlyError(t *testing.T) { diff --git a/heartbeat_test.go b/heartbeat_test.go index 839d50986..17235f57b 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -3,54 +3,43 @@ package gosnowflake import ( - "context" "testing" ) func TestUnitPostHeartbeat(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + // send heartbeat call and renew expired session + sr := &snowflakeRestful{ + FuncPost: postTestRenew, + FuncRenewSession: renewSessionTest, + TokenAccessor: getSimpleTokenAccessor(), + RequestTimeout: 0, + } + heartbeat := &heartbeat{ + restful: sr, + } + err := heartbeat.heartbeatMain() + if err != nil { + t.Fatalf("failed to heartbeat and renew session. err: %v", err) + } - // send heartbeat call and renew expired session - sr := &snowflakeRestful{ - FuncPost: postTestRenew, - FuncRenewSession: renewSessionTest, - TokenAccessor: getSimpleTokenAccessor(), - RequestTimeout: 0, - } - heartbeat := &heartbeat{ - restful: sr, - } - err = heartbeat.heartbeatMain() - if err != nil { - t.Fatalf("failed to heartbeat and renew session. err: %v", err) - } + heartbeat.restful.FuncPost = postTestSuccessButInvalidJSON + err = heartbeat.heartbeatMain() + if err == nil { + t.Fatal("should have failed") + } - heartbeat.restful.FuncPost = postTestSuccessButInvalidJSON - err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } - - heartbeat.restful.FuncPost = postTestAppForbiddenError - err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrFailedToHeartbeat { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToHeartbeat, driverErr.Number) - } + heartbeat.restful.FuncPost = postTestAppForbiddenError + err = heartbeat.heartbeatMain() + if err == nil { + t.Fatal("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrFailedToHeartbeat { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToHeartbeat, driverErr.Number) + } + }) } diff --git a/htap.go b/htap.go new file mode 100644 index 000000000..93806801f --- /dev/null +++ b/htap.go @@ -0,0 +1,91 @@ +package gosnowflake + +import ( + "sort" + "strconv" + "sync" +) + +const ( + queryContextCacheSizeParamName = "QUERY_CONTEXT_CACHE_SIZE" + defaultQueryContextCacheSize = 5 +) + +type queryContext struct { + Entries []queryContextEntry `json:"entries,omitempty"` +} + +type queryContextEntry struct { + ID int `json:"id"` + Timestamp int64 `json:"timestamp"` + Priority int `json:"priority"` + Context string `json:"context,omitempty"` +} + +type queryContextCache struct { + mutex *sync.Mutex + entries []queryContextEntry +} + +func (qcc *queryContextCache) init() *queryContextCache { + qcc.mutex = &sync.Mutex{} + return qcc +} + +func (qcc *queryContextCache) add(sc *snowflakeConn, qces ...queryContextEntry) { + qcc.mutex.Lock() + defer qcc.mutex.Unlock() + if len(qces) == 0 { + qcc.prune(0) + } else { + for _, newQce := range qces { + logger.Debugf("adding query context: %v", newQce) + newQceProcessed := false + for existingQceIdx, existingQce := range qcc.entries { + if newQce.ID == existingQce.ID { + newQceProcessed = true + if newQce.Timestamp > existingQce.Timestamp { + qcc.entries[existingQceIdx] = newQce + } else if newQce.Timestamp == existingQce.Timestamp { + if newQce.Priority != existingQce.Priority { + qcc.entries[existingQceIdx] = newQce + } + } + } + } + if !newQceProcessed { + for existingQceIdx, existingQce := range qcc.entries { + if newQce.Priority == existingQce.Priority { + qcc.entries[existingQceIdx] = newQce + newQceProcessed = true + } + } + } + if !newQceProcessed { + qcc.entries = append(qcc.entries, newQce) + } + } + sort.Slice(qcc.entries, func(idx1, idx2 int) bool { + return qcc.entries[idx1].Priority < qcc.entries[idx2].Priority + }) + qcc.prune(qcc.getQueryContextCacheSize(sc)) + } +} + +func (qcc *queryContextCache) prune(size int) { + if len(qcc.entries) > size { + qcc.entries = qcc.entries[0:size] + } +} + +func (qcc *queryContextCache) getQueryContextCacheSize(sc *snowflakeConn) int { + if sizeStr, ok := sc.cfg.Params[queryContextCacheSizeParamName]; ok { + size, err := strconv.Atoi(*sizeStr) + if err != nil { + logger.Warnf("cannot parse %v as int as query context cache size: %v", sizeStr, err) + } else { + return size + } + } + return defaultQueryContextCacheSize +} diff --git a/htap_test.go b/htap_test.go new file mode 100644 index 000000000..a724424f7 --- /dev/null +++ b/htap_test.go @@ -0,0 +1,411 @@ +package gosnowflake + +import ( + "database/sql/driver" + "fmt" + "reflect" + "testing" + "time" +) + +func TestSortingByPriority(t *testing.T) { + qcc := (&queryContextCache{}).init() + sc := htapTestSnowflakeConn() + + qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} + qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} + qceC := queryContextEntry{ID: 14, Timestamp: 125, Priority: 6, Context: "c"} + qceD := queryContextEntry{ID: 15, Timestamp: 126, Priority: 8, Context: "d"} + + t.Run("Add to empty cache", func(t *testing.T) { + qcc.add(sc, qceA) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA}) { + t.Fatalf("no entries added to cache. %v", qcc.entries) + } + }) + t.Run("Add another entry with different id, timestamp and priority - greater priority", func(t *testing.T) { + qcc.add(sc, qceB) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) + t.Run("Add another entry with different id, timestamp and priority - lesser priority", func(t *testing.T) { + qcc.add(sc, qceC) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceA, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) + t.Run("Add another entry with different id, timestamp and priority - priority in the middle", func(t *testing.T) { + qcc.add(sc, qceD) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceA, qceD, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) +} + +func TestAddingQcesWithTheSameIdAndLaterTimestamp(t *testing.T) { + qcc := (&queryContextCache{}).init() + sc := htapTestSnowflakeConn() + + qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} + qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} + qceC := queryContextEntry{ID: 12, Timestamp: 125, Priority: 6, Context: "c"} + qceD := queryContextEntry{ID: 12, Timestamp: 126, Priority: 6, Context: "d"} + + t.Run("Add to empty cache", func(t *testing.T) { + qcc.add(sc, qceA) + qcc.add(sc, qceB) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { + t.Fatalf("no entries added to cache. %v", qcc.entries) + } + }) + t.Run("Add another entry with different priority", func(t *testing.T) { + qcc.add(sc, qceC) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) + t.Run("Add another entry with same priority", func(t *testing.T) { + qcc.add(sc, qceD) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceD, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) +} + +func TestAddingQcesWithTheSameIdAndSameTimestamp(t *testing.T) { + qcc := (&queryContextCache{}).init() + sc := htapTestSnowflakeConn() + + qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} + qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} + qceC := queryContextEntry{ID: 12, Timestamp: 123, Priority: 6, Context: "c"} + qceD := queryContextEntry{ID: 12, Timestamp: 123, Priority: 6, Context: "d"} + + t.Run("Add to empty cache", func(t *testing.T) { + qcc.add(sc, qceA) + qcc.add(sc, qceB) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { + t.Fatalf("no entries added to cache. %v", qcc.entries) + } + }) + t.Run("Add another entry with different priority", func(t *testing.T) { + qcc.add(sc, qceC) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) + t.Run("Add another entry with same priority", func(t *testing.T) { + qcc.add(sc, qceD) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) +} + +func TestAddingQcesWithTheSameIdAndEarlierTimestamp(t *testing.T) { + qcc := (&queryContextCache{}).init() + sc := htapTestSnowflakeConn() + + qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} + qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} + qceC := queryContextEntry{ID: 12, Timestamp: 122, Priority: 6, Context: "c"} + qceD := queryContextEntry{ID: 12, Timestamp: 122, Priority: 7, Context: "d"} + + t.Run("Add to empty cache", func(t *testing.T) { + qcc.add(sc, qceA) + qcc.add(sc, qceB) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) + t.Run("Add another entry with different priority", func(t *testing.T) { + qcc.add(sc, qceC) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) + t.Run("Add another entry with same priority", func(t *testing.T) { + qcc.add(sc, qceD) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) +} + +func TestAddingQcesWithDifferentId(t *testing.T) { + qcc := (&queryContextCache{}).init() + sc := htapTestSnowflakeConn() + + qceA := queryContextEntry{ID: 12, Timestamp: 123, Priority: 7, Context: "a"} + qceB := queryContextEntry{ID: 13, Timestamp: 124, Priority: 9, Context: "b"} + qceC := queryContextEntry{ID: 14, Timestamp: 122, Priority: 7, Context: "c"} + qceD := queryContextEntry{ID: 15, Timestamp: 122, Priority: 6, Context: "d"} + + t.Run("Add to empty cache", func(t *testing.T) { + qcc.add(sc, qceA) + qcc.add(sc, qceB) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceA, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) + t.Run("Add another entry with same priority", func(t *testing.T) { + qcc.add(sc, qceC) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceC, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) + t.Run("Add another entry with different priority", func(t *testing.T) { + qcc.add(sc, qceD) + if !reflect.DeepEqual(qcc.entries, []queryContextEntry{qceD, qceC, qceB}) { + t.Fatalf("unexpected qcc entries. %v", qcc.entries) + } + }) +} + +func TestAddingQueryContextCacheEntry(t *testing.T) { + runSnowflakeConnTest(t, func(sct *SCTest) { + t.Run("First query (may be on empty cache)", func(t *testing.T) { + entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries)) + copy(entriesBefore, sct.sc.queryContextCache.entries) + sct.mustQuery("SELECT 1", nil) + entriesAfter := sct.sc.queryContextCache.entries + + if !containsNewEntries(entriesAfter, entriesBefore) { + t.Error("no new entries added to the query context cache") + } + }) + + t.Run("Second query (cache should not be empty)", func(t *testing.T) { + entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries)) + copy(entriesBefore, sct.sc.queryContextCache.entries) + if len(entriesBefore) == 0 { + t.Fatalf("cache should not be empty after first query") + } + sct.mustQuery("SELECT 2", nil) + entriesAfter := sct.sc.queryContextCache.entries + + if !containsNewEntries(entriesAfter, entriesBefore) { + t.Error("no new entries added to the query context cache") + } + }) + }) +} + +func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryContextEntry) bool { + if len(entriesAfter) > len(entriesBefore) { + return true + } + + for _, entryAfter := range entriesAfter { + for _, entryBefore := range entriesBefore { + if !reflect.DeepEqual(entryBefore, entryAfter) { + return true + } + } + } + + return false +} + +func TestPruneBySessionValue(t *testing.T) { + qce1 := queryContextEntry{1, 1, 1, ""} + qce2 := queryContextEntry{2, 2, 2, ""} + qce3 := queryContextEntry{3, 3, 3, ""} + + testcases := []struct { + size string + expected []queryContextEntry + }{ + { + size: "1", + expected: []queryContextEntry{qce1}, + }, + { + size: "2", + expected: []queryContextEntry{qce1, qce2}, + }, + { + size: "3", + expected: []queryContextEntry{qce1, qce2, qce3}, + }, + { + size: "4", + expected: []queryContextEntry{qce1, qce2, qce3}, + }, + } + + for _, tc := range testcases { + t.Run(tc.size, func(t *testing.T) { + sc := &snowflakeConn{ + cfg: &Config{ + Params: map[string]*string{ + queryContextCacheSizeParamName: &tc.size, + }, + }, + } + + qcc := (&queryContextCache{}).init() + + qcc.add(sc, qce1) + qcc.add(sc, qce2) + qcc.add(sc, qce3) + + if !reflect.DeepEqual(qcc.entries, tc.expected) { + t.Errorf("unexpected cache entries. expected: %v, got: %v", tc.expected, qcc.entries) + } + }) + } +} + +func TestPruneByDefaultValue(t *testing.T) { + qce1 := queryContextEntry{1, 1, 1, ""} + qce2 := queryContextEntry{2, 2, 2, ""} + qce3 := queryContextEntry{3, 3, 3, ""} + qce4 := queryContextEntry{4, 4, 4, ""} + qce5 := queryContextEntry{5, 5, 5, ""} + qce6 := queryContextEntry{6, 6, 6, ""} + + sc := &snowflakeConn{ + cfg: &Config{ + Params: map[string]*string{}, + }, + } + + qcc := (&queryContextCache{}).init() + qcc.add(sc, qce1) + qcc.add(sc, qce2) + qcc.add(sc, qce3) + qcc.add(sc, qce4) + qcc.add(sc, qce5) + + if len(qcc.entries) != 5 { + t.Fatalf("Expected 5 elements, got: %v", len(qcc.entries)) + } + + qcc.add(sc, qce6) + if len(qcc.entries) != 5 { + t.Fatalf("Expected 5 elements, got: %v", len(qcc.entries)) + } +} + +func TestNoQcesClearsCache(t *testing.T) { + qce1 := queryContextEntry{1, 1, 1, ""} + + sc := &snowflakeConn{ + cfg: &Config{ + Params: map[string]*string{}, + }, + } + + qcc := (&queryContextCache{}).init() + qcc.add(sc, qce1) + + if len(qcc.entries) != 1 { + t.Fatalf("improperly inited cache") + } + + qcc.add(sc) + + if len(qcc.entries) != 0 { + t.Errorf("after adding empty context list cache should be cleared") + } +} + +func htapTestSnowflakeConn() *snowflakeConn { + return &snowflakeConn{ + cfg: &Config{ + Params: map[string]*string{}, + }, + } +} + +func TestQueryContextCacheDisabled(t *testing.T) { + origDsn := dsn + defer func() { + dsn = origDsn + }() + dsn += "&disableQueryContextCache=true" + runSnowflakeConnTest(t, func(sct *SCTest) { + sct.mustExec("SELECT 1", nil) + if len(sct.sc.queryContextCache.entries) > 0 { + t.Error("should not contain any entries") + } + }) +} + +func TestHybridTablesE2E(t *testing.T) { + if runningOnGithubAction() && !runningOnAWS() { + t.Skip("HTAP is enabled only on AWS") + } + runID := time.Now().UnixMilli() + testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID) + testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID) + runSnowflakeConnTest(t, func(sct *SCTest) { + dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil) + defer dbQuery.Close() + currentDb := make([]driver.Value, 1) + dbQuery.Next(currentDb) + defer func() { + sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil) + sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil) + sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil) + }() + + t.Run("Run tests on first database", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil) + sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil) + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil) + rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows.Close() + row := make([]driver.Value, 2) + rows.Next(row) + if row[0] != "1" || row[1] != "a" { + t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1]) + } + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil) + rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows2.Close() + rows2.Next(row) + if row[0] != "1" || row[1] != "a" { + t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1]) + } + rows2.Next(row) + if row[0] != "2" || row[1] != "b" { + t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1]) + } + if len(sct.sc.queryContextCache.entries) != 2 { + t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + t.Run("Run tests on second database", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil) + sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil) + sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil) + + rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil) + defer rows.Close() + row := make([]driver.Value, 2) + rows.Next(row) + if row[0] != "3" || row[1] != "c" { + t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1]) + } + if len(sct.sc.queryContextCache.entries) != 3 { + t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + t.Run("Run tests on first database again", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil) + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil) + + rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows.Close() + if len(sct.sc.queryContextCache.entries) != 3 { + t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + }) +} diff --git a/multistatement.go b/multistatement.go index c8b13e21d..ce9d9910b 100644 --- a/multistatement.go +++ b/multistatement.go @@ -78,7 +78,7 @@ func (sc *snowflakeConn) handleMultiExec( return &snowflakeResult{ affectedRows: updatedRows, insertID: -1, - queryID: sc.QueryID, + queryID: data.QueryID, }, nil } diff --git a/multistatement_test.go b/multistatement_test.go index c7a5c5cae..ce0d713bf 100644 --- a/multistatement_test.go +++ b/multistatement_test.go @@ -23,10 +23,7 @@ func TestMultiStatementExecuteNoResultSet(t *testing.T) { "commit;" runDBTest(t, func(dbt *DBTest) { - dbt.mustExec("drop table if exists test_multi_statement_txn") - dbt.mustExec(`create or replace table test_multi_statement_txn( - c1 number, c2 string) as select 10, 'z'`) - defer dbt.mustExec("drop table if exists test_multi_statement_txn") + dbt.mustExec(`create or replace table test_multi_statement_txn(c1 number, c2 string) as select 10, 'z'`) res := dbt.mustExecContext(ctx, multiStmtQuery) count, err := res.RowsAffected() @@ -48,6 +45,7 @@ func TestMultiStatementQueryResultSet(t *testing.T) { var v1, v2, v3 int64 var v4 string + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, multiStmtQuery) defer rows.Close() @@ -481,115 +479,97 @@ func funcGetQueryRespError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ } func TestUnitHandleMultiExec(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - data := &execResponseData{ - ResultIDs: "", - ResultTypes: "", - } - _, err = sc.handleMultiExec(context.Background(), *data) - if err == nil { - t.Fatalf("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrNoResultIDs { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + data := execResponseData{ + ResultIDs: "", + ResultTypes: "", + } + _, err := sct.sc.handleMultiExec(context.Background(), data) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrNoResultIDs { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) + } - sr := &snowflakeRestful{ - FuncGet: funcGetQueryRespFail, - TokenAccessor: getSimpleTokenAccessor(), - } - data = &execResponseData{ - ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", - ResultTypes: "12544,12544", - } - sc.rest = sr - _, err = sc.handleMultiExec(context.Background(), *data) - if err == nil { - t.Fatalf("should have failed") - } + data = execResponseData{ + ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", + ResultTypes: "12544,12544", + } + sct.sc.rest = &snowflakeRestful{ + FuncGet: funcGetQueryRespFail, + FuncCloseSession: closeSessionMock, + TokenAccessor: getSimpleTokenAccessor(), + } + _, err = sct.sc.handleMultiExec(context.Background(), data) + if err == nil { + t.Fatalf("should have failed") + } - sc.rest.FuncGet = funcGetQueryRespError - data.SQLState = "01112" - _, err = sc.handleMultiExec(context.Background(), *data) - if err == nil { - t.Fatalf("should have failed") - } - driverErr, ok = err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrFailedToPostQuery { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) - } + sct.sc.rest.FuncGet = funcGetQueryRespError + data.SQLState = "01112" + _, err = sct.sc.handleMultiExec(context.Background(), data) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok = err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrFailedToPostQuery { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) + } + }) } func TestUnitHandleMultiQuery(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - data := &execResponseData{ - ResultIDs: "", - ResultTypes: "", - } - rows := new(snowflakeRows) - err = sc.handleMultiQuery(context.Background(), *data, rows) - if err == nil { - t.Fatalf("should have failed") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrNoResultIDs { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) - } - sr := &snowflakeRestful{ - FuncGet: funcGetQueryRespFail, - TokenAccessor: getSimpleTokenAccessor(), - } - data = &execResponseData{ - ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", - ResultTypes: "12544,12544", - } - sc.rest = sr - err = sc.handleMultiQuery(context.Background(), *data, rows) - if err == nil { - t.Fatalf("should have failed") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + data := execResponseData{ + ResultIDs: "", + ResultTypes: "", + } + rows := new(snowflakeRows) + err := sct.sc.handleMultiQuery(context.Background(), data, rows) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrNoResultIDs { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) + } + data = execResponseData{ + ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", + ResultTypes: "12544,12544", + } + sct.sc.rest = &snowflakeRestful{ + FuncGet: funcGetQueryRespFail, + FuncCloseSession: closeSessionMock, + TokenAccessor: getSimpleTokenAccessor(), + } + err = sct.sc.handleMultiQuery(context.Background(), data, rows) + if err == nil { + t.Fatalf("should have failed") + } - sc.rest.FuncGet = funcGetQueryRespError - data.SQLState = "01112" - err = sc.handleMultiQuery(context.Background(), *data, rows) - if err == nil { - t.Fatalf("should have failed") - } - driverErr, ok = err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrFailedToPostQuery { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) - } + sct.sc.rest.FuncGet = funcGetQueryRespError + data.SQLState = "01112" + err = sct.sc.handleMultiQuery(context.Background(), data, rows) + if err == nil { + t.Fatalf("should have failed") + } + driverErr, ok = err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrFailedToPostQuery { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) + } + }) } diff --git a/put_get_with_aws_test.go b/put_get_with_aws_test.go index b02fd27a8..b5753d295 100644 --- a/put_get_with_aws_test.go +++ b/put_get_with_aws_test.go @@ -87,115 +87,102 @@ func TestLoadS3(t *testing.T) { } func TestPutWithInvalidToken(t *testing.T) { - if !runningOnAWS() { - t.Skip("skipping non aws environment") - } - tmpDir, err := os.MkdirTemp("", "aws_put") - if err != nil { - t.Error(err) - } - defer os.RemoveAll(tmpDir) - fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") - originalContents := "123,test1\n456,test2\n" - - var b bytes.Buffer - gzw := gzip.NewWriter(&b) - gzw.Write([]byte(originalContents)) - gzw.Close() - if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { - t.Fatal("could not write to gzip file") - } - - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - tableName := randomString(5) - if _, err = sc.Exec("create or replace table "+tableName+ - " (a int, b string)", nil); err != nil { - t.Fatal(err) - } - defer sc.Exec("drop table "+tableName, nil) + runSnowflakeConnTest(t, func(sct *SCTest) { + if !runningOnAWS() { + t.Skip("skipping non aws environment") + } + tmpDir, err := os.MkdirTemp("", "aws_put") + if err != nil { + t.Error(err) + } + defer os.RemoveAll(tmpDir) + fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") + originalContents := "123,test1\n456,test2\n" + + var b bytes.Buffer + gzw := gzip.NewWriter(&b) + gzw.Write([]byte(originalContents)) + gzw.Close() + if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { + t.Fatal("could not write to gzip file") + } - jsonBody, err := json.Marshal(execRequest{ - SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), - }) - if err != nil { - t.Error(err) - } - headers := getHeaders() - headers[httpHeaderAccept] = headerContentTypeApplicationJSON - data, err := sc.rest.FuncPostQuery( - sc.ctx, sc.rest, &url.Values{}, headers, jsonBody, - sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sc.ctx), sc.cfg) - if err != nil { - t.Fatal(err) - } + tableName := randomString(5) + sct.mustExec("create or replace table "+tableName+" (a int, b string)", nil) + defer sct.mustExec("drop table "+tableName, nil) - s3Util := new(snowflakeS3Client) - s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) - if err != nil { - t.Error(err) - } - client := s3Cli.(*s3.Client) + jsonBody, err := json.Marshal(execRequest{ + SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), + }) + if err != nil { + t.Error(err) + } + headers := getHeaders() + headers[httpHeaderAccept] = headerContentTypeApplicationJSON + data, err := sct.sc.rest.FuncPostQuery( + sct.sc.ctx, sct.sc.rest, &url.Values{}, headers, jsonBody, + sct.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sct.sc.ctx), sct.sc.cfg) + if err != nil { + t.Fatal(err) + } - s3Loc, err := s3Util.extractBucketNameAndPath(data.Data.StageInfo.Location) - if err != nil { - t.Error(err) - } - s3Path := s3Loc.s3Path + baseName(fname) + ".gz" + s3Util := new(snowflakeS3Client) + s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) + if err != nil { + t.Error(err) + } + client := s3Cli.(*s3.Client) - f, err := os.Open(fname) - if err != nil { - t.Error(err) - } - defer f.Close() - uploader := manager.NewUploader(client) - if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ - Bucket: &s3Loc.bucketName, - Key: &s3Path, - Body: f, - }); err != nil { - t.Fatal(err) - } + s3Loc, err := s3Util.extractBucketNameAndPath(data.Data.StageInfo.Location) + if err != nil { + t.Error(err) + } + s3Path := s3Loc.s3Path + baseName(fname) + ".gz" - parentPath := filepath.Dir(filepath.Dir(s3Path)) + "/" - if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ - Bucket: &s3Loc.bucketName, - Key: &parentPath, - Body: f, - }); err == nil { - t.Fatal("should have failed attempting to put file in parent path") - } + f, err := os.Open(fname) + if err != nil { + t.Error(err) + } + defer f.Close() + uploader := manager.NewUploader(client) + if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &s3Loc.bucketName, + Key: &s3Path, + Body: f, + }); err != nil { + t.Fatal(err) + } - info := execResponseStageInfo{ - Creds: execResponseCredentials{ - AwsID: data.Data.StageInfo.Creds.AwsID, - AwsSecretKey: data.Data.StageInfo.Creds.AwsSecretKey, - }, - } - s3Cli, err = s3Util.createClient(&info, false) - if err != nil { - t.Error(err) - } - client = s3Cli.(*s3.Client) + parentPath := filepath.Dir(filepath.Dir(s3Path)) + "/" + if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &s3Loc.bucketName, + Key: &parentPath, + Body: f, + }); err == nil { + t.Fatal("should have failed attempting to put file in parent path") + } - uploader = manager.NewUploader(client) - if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ - Bucket: &s3Loc.bucketName, - Key: &s3Path, - Body: f, - }); err == nil { - t.Fatal("should have failed attempting to put with missing aws token") - } + info := execResponseStageInfo{ + Creds: execResponseCredentials{ + AwsID: data.Data.StageInfo.Creds.AwsID, + AwsSecretKey: data.Data.StageInfo.Creds.AwsSecretKey, + }, + } + s3Cli, err = s3Util.createClient(&info, false) + if err != nil { + t.Error(err) + } + client = s3Cli.(*s3.Client) + + uploader = manager.NewUploader(client) + if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ + Bucket: &s3Loc.bucketName, + Key: &s3Path, + Body: f, + }); err == nil { + t.Fatal("should have failed attempting to put with missing aws token") + } + }) } func TestPretendToPutButList(t *testing.T) { @@ -218,50 +205,38 @@ func TestPretendToPutButList(t *testing.T) { t.Fatal("could not write to gzip file") } - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + tableName := randomString(5) + sct.mustExec("create or replace table "+tableName+ + " (a int, b string)", nil) + defer sct.sc.Exec("drop table "+tableName, nil) - tableName := randomString(5) - if _, err = sc.Exec("create or replace table "+tableName+ - " (a int, b string)", nil); err != nil { - t.Fatal(err) - } - defer sc.Exec("drop table "+tableName, nil) + jsonBody, err := json.Marshal(execRequest{ + SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), + }) + if err != nil { + t.Error(err) + } + headers := getHeaders() + headers[httpHeaderAccept] = headerContentTypeApplicationJSON + data, err := sct.sc.rest.FuncPostQuery( + sct.sc.ctx, sct.sc.rest, &url.Values{}, headers, jsonBody, + sct.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sct.sc.ctx), sct.sc.cfg) + if err != nil { + t.Fatal(err) + } - jsonBody, err := json.Marshal(execRequest{ - SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), + s3Util := new(snowflakeS3Client) + s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) + if err != nil { + t.Error(err) + } + client := s3Cli.(*s3.Client) + if _, err = client.ListBuckets(context.Background(), + &s3.ListBucketsInput{}); err == nil { + t.Fatal("list buckets should fail") + } }) - if err != nil { - t.Error(err) - } - headers := getHeaders() - headers[httpHeaderAccept] = headerContentTypeApplicationJSON - data, err := sc.rest.FuncPostQuery( - sc.ctx, sc.rest, &url.Values{}, headers, jsonBody, - sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sc.ctx), sc.cfg) - if err != nil { - t.Fatal(err) - } - - s3Util := new(snowflakeS3Client) - s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) - if err != nil { - t.Error(err) - } - client := s3Cli.(*s3.Client) - if _, err = client.ListBuckets(context.Background(), - &s3.ListBucketsInput{}); err == nil { - t.Fatal("list buckets should fail") - } } func TestPutGetAWSStage(t *testing.T) { diff --git a/query.go b/query.go index db76d1624..300233d0e 100644 --- a/query.go +++ b/query.go @@ -3,6 +3,7 @@ package gosnowflake import ( + "encoding/json" "time" ) @@ -27,6 +28,22 @@ type execRequest struct { Parameters map[string]interface{} `json:"parameters,omitempty"` Bindings map[string]execBindParameter `json:"bindings,omitempty"` BindStage string `json:"bindStage,omitempty"` + QueryContext requestQueryContext `json:"queryContextDTO,omitempty"` +} + +type requestQueryContext struct { + Entries []requestQueryContextEntry `json:"entries,omitempty"` +} + +type requestQueryContextEntry struct { + Context contextData `json:"context,omitempty"` + ID int `json:"id"` + Priority int `json:"priority"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +type contextData struct { + Base64Data string `json:"base64Data,omitempty"` } type execResponseRowType struct { @@ -118,6 +135,9 @@ type execResponseData struct { Command string `json:"command,omitempty"` Kind string `json:"kind,omitempty"` Operation string `json:"operation,omitempty"` + + // HTAP + QueryContext json.RawMessage `json:"queryContext,omitempty"` } type execResponse struct { diff --git a/rows_test.go b/rows_test.go index bf414f825..fabb6c3bd 100644 --- a/rows_test.go +++ b/rows_test.go @@ -452,45 +452,25 @@ func TestDownloadChunkErrorStatus(t *testing.T) { func TestWithArrowBatchesNotImplementedForResult(t *testing.T) { ctx := WithArrowBatches(context.Background()) - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(ctx, *config) - if err != nil { - t.Error(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Error(err) - } - - if _, err = sc.Exec("create or replace table testArrowBatches (a int, b int)", nil); err != nil { - t.Fatal(err) - } - defer sc.Exec("drop table if exists testArrowBatches", nil) + runSnowflakeConnTest(t, func(sct *SCTest) { - result, err := sc.ExecContext(ctx, "insert into testArrowBatches values (1, 2), (3, 4), (5, 6)", []driver.NamedValue{}) - if err != nil { - t.Error(err) - } + sct.mustExec("create or replace table testArrowBatches (a int, b int)", nil) + defer sct.sc.Exec("drop table if exists testArrowBatches", nil) - result.(*snowflakeResult).GetStatus() - queryID := result.(*snowflakeResult).GetQueryID() - if queryID != sc.QueryID { - t.Fatalf("failed to get query ID") - } + result := sct.mustExecContext(ctx, "insert into testArrowBatches values (1, 2), (3, 4), (5, 6)", []driver.NamedValue{}) - _, err = result.(*snowflakeResult).GetArrowBatches() - if err == nil { - t.Fatal("should have raised an error") - } - driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrNotImplemented { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNotImplemented, driverErr.Number) - } + _, err := result.(*snowflakeResult).GetArrowBatches() + if err == nil { + t.Fatal("should have raised an error") + } + driverErr, ok := err.(*SnowflakeError) + if !ok { + t.Fatalf("should be snowflake error. err: %v", err) + } + if driverErr.Number != ErrNotImplemented { + t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNotImplemented, driverErr.Number) + } + }) } func TestLocationChangesAfterAlterSession(t *testing.T) { diff --git a/statement.go b/statement.go index 203009986..70d4479a7 100644 --- a/statement.go +++ b/statement.go @@ -7,9 +7,15 @@ import ( "database/sql/driver" ) +// SnowflakeStmt represents the prepared statement in driver. +type SnowflakeStmt interface { + GetQueryID() string +} + type snowflakeStmt struct { - sc *snowflakeConn - query string + sc *snowflakeConn + query string + lastQueryID string } func (stmt *snowflakeStmt) Close() error { @@ -26,20 +32,32 @@ func (stmt *snowflakeStmt) NumInput() int { func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext") - return stmt.sc.ExecContext(ctx, stmt.query, args) + result, err := stmt.sc.ExecContext(ctx, stmt.query, args) + stmt.lastQueryID = result.(SnowflakeResult).GetQueryID() + return result, err } func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext") - return stmt.sc.QueryContext(ctx, stmt.query, args) + rows, err := stmt.sc.QueryContext(ctx, stmt.query, args) + stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID() + return rows, err } func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec") - return stmt.sc.Exec(stmt.query, args) + result, err := stmt.sc.Exec(stmt.query, args) + stmt.lastQueryID = result.(SnowflakeResult).GetQueryID() + return result, err } func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) { logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query") - return stmt.sc.Query(stmt.query, args) + rows, err := stmt.sc.Query(stmt.query, args) + stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID() + return rows, err +} + +func (stmt *snowflakeStmt) GetQueryID() string { + return stmt.lastQueryID } diff --git a/statement_test.go b/statement_test.go index de496c8aa..68f1e7a7b 100644 --- a/statement_test.go +++ b/statement_test.go @@ -1,4 +1,5 @@ // Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved. +//lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility. package gosnowflake @@ -287,3 +288,129 @@ func TestUnitCheckQueryStatus(t *testing.T) { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrQueryStatus, driverErr.Number) } } + +func TestStatementQueryIdForQueries(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + testcases := []struct { + name string + f func(stmt driver.Stmt) (driver.Rows, error) + }{ + { + "query", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.Query(nil) + }, + }, + { + "queryContext", + func(stmt driver.Stmt) (driver.Rows, error) { + return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + firstQuery, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != firstQuery.(SnowflakeRows).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + secondQuery, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != secondQuery.(SnowflakeRows).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestStatementQueryIdForExecs(t *testing.T) { + ctx := context.Background() + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)") + defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs") + + testcases := []struct { + name string + f func(stmt driver.Stmt) (driver.Result, error) + }{ + { + "exec", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.Exec(nil) + }, + }, + { + "execContext", + func(stmt driver.Stmt) (driver.Result, error) { + return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + err := dbt.conn.Raw(func(x any) error { + stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)") + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() != "" { + t.Error("queryId should be empty before executing any query") + } + firstExec, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != firstExec.(SnowflakeResult).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + secondExec, err := tc.f(stmt) + if err != nil { + t.Fatal(err) + } + if stmt.(SnowflakeStmt).GetQueryID() == "" { + t.Error("queryId should not be empty after executing query") + } + if stmt.(SnowflakeStmt).GetQueryID() != secondExec.(SnowflakeResult).GetQueryID() { + t.Error("queryId should be equal among query result and prepared statement") + } + return nil + }) + if err != nil { + t.Fatal(err) + } + }) + } + }) +} diff --git a/telemetry_test.go b/telemetry_test.go index 9a09c2bdb..c738de7e2 100644 --- a/telemetry_test.go +++ b/telemetry_test.go @@ -14,86 +14,65 @@ import ( ) func TestTelemetryAddLog(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - st := &snowflakeTelemetry{ - sr: sc.rest, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: defaultFlushSize, - } - rand.Seed(time.Now().UnixNano()) - randNum := rand.Int() % 10000 - for i := 0; i < randNum; i++ { - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { + runSnowflakeConnTest(t, func(sct *SCTest) { + st := &snowflakeTelemetry{ + sr: sct.sc.rest, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } + rand.Seed(time.Now().UnixNano()) + randNum := rand.Int() % 10000 + for i := 0; i < randNum; i++ { + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + } + if len(st.logs) != randNum%defaultFlushSize { + t.Errorf("length of remaining logs does not match. expected: %v, got: %v", + randNum%defaultFlushSize, len(st.logs)) + } + if err := st.sendBatch(); err != nil { t.Fatal(err) } - } - if len(st.logs) != randNum%defaultFlushSize { - t.Errorf("length of remaining logs does not match. expected: %v, got: %v", - randNum%defaultFlushSize, len(st.logs)) - } - if err = st.sendBatch(); err != nil { - t.Fatal(err) - } + }) } func TestTelemetrySQLException(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - st := &snowflakeTelemetry{ - sr: sc.rest, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: defaultFlushSize, - } - sc.telemetry = st - sfa := &snowflakeFileTransferAgent{ - sc: sc, - commandType: uploadCommand, - srcFiles: make([]string, 0), - data: &execResponseData{ - SrcLocations: make([]string, 0), - }, - } - if err = sfa.initFileMetadata(); err == nil { - t.Fatal("this should have thrown an error") - } - if len(st.logs) != 1 { - t.Errorf("there should be 1 telemetry data in log. found: %v", len(st.logs)) - } - if sendErr := st.sendBatch(); sendErr != nil { - t.Fatal(sendErr) - } - if len(st.logs) != 0 { - t.Errorf("there should be no telemetry data in log. found: %v", len(st.logs)) - } + runSnowflakeConnTest(t, func(sct *SCTest) { + sct.sc.telemetry = &snowflakeTelemetry{ + sr: sct.sc.rest, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } + sfa := &snowflakeFileTransferAgent{ + sc: sct.sc, + commandType: uploadCommand, + srcFiles: make([]string, 0), + data: &execResponseData{ + SrcLocations: make([]string, 0), + }, + } + if err := sfa.initFileMetadata(); err == nil { + t.Fatal("this should have thrown an error") + } + if len(sct.sc.telemetry.logs) != 1 { + t.Errorf("there should be 1 telemetry data in log. found: %v", len(sct.sc.telemetry.logs)) + } + if sendErr := sct.sc.telemetry.sendBatch(); sendErr != nil { + t.Fatal(sendErr) + } + if len(sct.sc.telemetry.logs) != 0 { + t.Errorf("there should be no telemetry data in log. found: %v", len(sct.sc.telemetry.logs)) + } + }) } func TestDisableTelemetry(t *testing.T) { @@ -118,24 +97,14 @@ func TestDisableTelemetry(t *testing.T) { } func TestEnableTelemetry(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - config.DisableTelemetry = false - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - if sc.cfg.DisableTelemetry { - t.Errorf("DisableTelemetry should be false. DisableTelemetry: %v", sc.cfg.DisableTelemetry) - } - if !sc.telemetry.enabled { - t.Errorf("telemetry should be enabled.") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + if sct.sc.cfg.DisableTelemetry { + t.Errorf("DisableTelemetry should be false. DisableTelemetry: %v", sct.sc.cfg.DisableTelemetry) + } + if !sct.sc.telemetry.enabled { + t.Errorf("telemetry should be enabled.") + } + }) } func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) { @@ -143,203 +112,161 @@ func funcPostTelemetryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.UR } func TestTelemetryError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sr := &snowflakeRestful{ - FuncPost: funcPostTelemetryRespFail, - TokenAccessor: getSimpleTokenAccessor(), - } - st := &snowflakeTelemetry{ - sr: sr, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: defaultFlushSize, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + st := &snowflakeTelemetry{ + sr: &snowflakeRestful{ + FuncPost: funcPostTelemetryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + }, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } + err := st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + }) } func TestTelemetryDisabledOnBadResponse(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sr := &snowflakeRestful{ - FuncPost: postTestAppBadGatewayError, - TokenAccessor: getSimpleTokenAccessor(), - } - st := &snowflakeTelemetry{ - sr: sr, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: defaultFlushSize, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + st := &snowflakeTelemetry{ + sr: &snowflakeRestful{ + FuncPost: postTestAppBadGatewayError, + TokenAccessor: getSimpleTokenAccessor(), + }, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: defaultFlushSize, + } - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } - if st.enabled == true { - t.Fatal("telemetry should be disabled") - } + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err := st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } - st.enabled = true - st.sr.FuncPost = postTestQueryNotExecuting - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } - if st.enabled == true { - t.Fatal("telemetry should be disabled") - } + st.enabled = true + st.sr.FuncPost = postTestQueryNotExecuting + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } - st.enabled = true - st.sr.FuncPost = postTestSuccessButInvalidJSON - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } - if st.enabled == true { - t.Fatal("telemetry should be disabled") - } + st.enabled = true + st.sr.FuncPost = postTestSuccessButInvalidJSON + if err = st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + err = st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + if st.enabled == true { + t.Fatal("telemetry should be disabled") + } + }) } func TestTelemetryDisabled(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - sr := &snowflakeRestful{ - FuncPost: postTestAppBadGatewayError, - TokenAccessor: getSimpleTokenAccessor(), - } - st := &snowflakeTelemetry{ - sr: sr, - mutex: &sync.Mutex{}, - enabled: false, // disable - flushSize: defaultFlushSize, - } - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err == nil { - t.Fatal("should have failed") - } - st.enabled = true - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err != nil { - t.Fatal(err) - } - st.enabled = false - err = st.sendBatch() - if err == nil { - t.Fatal("should have failed") - } + runSnowflakeConnTest(t, func(sct *SCTest) { + st := &snowflakeTelemetry{ + sr: &snowflakeRestful{ + FuncPost: postTestAppBadGatewayError, + TokenAccessor: getSimpleTokenAccessor(), + }, + mutex: &sync.Mutex{}, + enabled: false, // disable + flushSize: defaultFlushSize, + } + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err == nil { + t.Fatal("should have failed") + } + st.enabled = true + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err != nil { + t.Fatal(err) + } + st.enabled = false + err := st.sendBatch() + if err == nil { + t.Fatal("should have failed") + } + }) } func TestAddLogError(t *testing.T) { - config, err := ParseDSN(dsn) - if err != nil { - t.Error(err) - } - sc, err := buildSnowflakeConn(context.Background(), *config) - if err != nil { - t.Fatal(err) - } - if err = authenticateWithConfig(sc); err != nil { - t.Fatal(err) - } - - sr := &snowflakeRestful{ - FuncPost: funcPostTelemetryRespFail, - TokenAccessor: getSimpleTokenAccessor(), - } - - st := &snowflakeTelemetry{ - sr: sr, - mutex: &sync.Mutex{}, - enabled: true, - flushSize: 1, - } + runSnowflakeConnTest(t, func(sct *SCTest) { + st := &snowflakeTelemetry{ + sr: &snowflakeRestful{ + FuncPost: funcPostTelemetryRespFail, + TokenAccessor: getSimpleTokenAccessor(), + }, + mutex: &sync.Mutex{}, + enabled: true, + flushSize: 1, + } - if err = st.addLog(&telemetryData{ - Message: map[string]string{ - typeKey: "client_telemetry_type", - queryIDKey: "123", - }, - Timestamp: time.Now().UnixNano() / int64(time.Millisecond), - }); err == nil { - t.Fatal("should have failed") - } + if err := st.addLog(&telemetryData{ + Message: map[string]string{ + typeKey: "client_telemetry_type", + queryIDKey: "123", + }, + Timestamp: time.Now().UnixNano() / int64(time.Millisecond), + }); err == nil { + t.Fatal("should have failed") + } + }) }