diff --git a/connection.go b/connection.go
index a48af3600..148e435d7 100644
--- a/connection.go
+++ b/connection.go
@@ -85,6 +85,10 @@ func (sc *snowflakeConn) exec(
 	var err error
 	counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter
+	queryContext, err := buildQueryContext(sc.queryContextCache)
+	if err != nil {
+		logger.Errorf("error while building query context: %v", err)
+	}
 	req := execRequest{
 		SQLText:   query,
 		AsyncExec: noResult,
@@ -92,6 +96,7 @@ func (sc *snowflakeConn) exec(
 		IsInternal:   isInternal,
 		DescribeOnly: describeOnly,
 		SequenceID:   counter,
+		QueryContext: queryContext,
 	}
 	if key := ctx.Value(multiStatementCount); key != nil {
 		req.Parameters[string(multiStatementCount)] = key
@@ -160,6 +165,27 @@ func (sc *snowflakeConn) exec(
 	return data, err
 }
 
+func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) {
+	rqc := requestQueryContext{}
+	if qcc == nil || len(qcc.entries) == 0 {
+		logger.Debugf("empty qcc")
+		return rqc, nil
+	}
+	for _, qce := range qcc.entries {
+		contextData := contextData{}
+		if qce.Context != "" {
+			contextData.Base64Data = qce.Context
+		}
+		rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
+			ID:        qce.ID,
+			Priority:  qce.Priority,
+			Timestamp: qce.Timestamp,
+			Context:   contextData,
+		})
+	}
+	return rqc, nil
+}
+
 func (sc *snowflakeConn) Begin() (driver.Tx, error) {
 	return sc.BeginTx(sc.ctx, driver.TxOptions{})
 }
diff --git a/driver_test.go b/driver_test.go
index 653c62777..9d471d784 100644
--- a/driver_test.go
+++ b/driver_test.go
@@ -6,6 +6,7 @@ import (
 	"context"
 	"crypto/rsa"
 	"database/sql"
+	"database/sql/driver"
 	"flag"
 	"fmt"
 	"net/http"
@@ -320,6 +321,34 @@ func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) {
 	return stmt
 }
 
+type SCTest struct {
+	*testing.T
+	sc *snowflakeConn
+}
+
+func (sct *SCTest) fail(method, query string, err error) {
+	if len(query) > 300 {
+		query = "[query too large to print]"
+	}
+	sct.Fatalf("error on %s [%s]: %s", method, query, err.Error())
+}
+
+func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result {
+	result, err := sct.sc.Exec(query, args)
+	if err != nil {
+		sct.fail("exec", query, err)
+	}
+	return result
+}
+
+func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows {
+	rows, err := sct.sc.Query(query, args)
+	if err != nil {
+		sct.fail("query", query, err)
+	}
+	return rows
+}
+
 func runDBTest(t *testing.T, test func(dbt *DBTest)) {
 	conn := openConn(t)
 	defer conn.Close()
@@ -328,7 +357,7 @@ func runDBTest(t *testing.T, test func(dbt *DBTest)) {
 	test(dbt)
 }
 
-func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) {
+func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) {
 	config, err := ParseDSN(dsn)
 	if err != nil {
 		t.Error(err)
@@ -342,7 +371,9 @@ func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) {
 		t.Fatal(err)
 	}
 
-	test(sc)
+	sct := &SCTest{t, sc}
+
+	test(sct)
 }
 
 func runningOnAWS() bool {
diff --git a/htap.go b/htap.go
index 6c33d4f48..22d4bbce6 100644
--- a/htap.go
+++ b/htap.go
@@ -12,10 +12,10 @@ const (
 )
 
 type queryContextEntry struct {
-	ID        int   `json:"id"`
-	Timestamp int64 `json:"timestamp"`
-	Priority  int   `json:"priority"`
-	Context   any   `json:"context,omitempty"`
+	ID        int    `json:"id"`
+	Timestamp int64  `json:"timestamp"`
+	Priority  int    `json:"priority"`
+	Context   string `json:"context,omitempty"`
 }
 
 type queryContextCache struct {
diff --git a/htap_test.go b/htap_test.go
index e2e71eb20..0a5c7cb49 100644
--- a/htap_test.go
+++ b/htap_test.go
@@ -1,106 +1,13 @@
 package gosnowflake
 
 import (
-	"encoding/json"
+	"database/sql/driver"
+	"fmt"
 	"reflect"
-	"strings"
 	"testing"
+	"time"
 )
 
-func TestMarshallAndDecodeOpaqueContext(t *testing.T) {
-	testcases := []struct {
-		json string
-		qc   queryContextEntry
-	}{
-		{
-			json: `{
-				"id": 1,
-				"timestamp": 2,
-				"priority": 3
-			}`,
-			qc: queryContextEntry{1, 2, 3, nil},
-		},
-		{
-			json: `{
-				"id": 1,
-				"timestamp": 2,
-				"priority": 3,
-				"context": "abc"
-			}`,
-			qc: queryContextEntry{1, 2, 3, "abc"},
-		},
-		{
-			json: `{
-				"id": 1,
-				"timestamp": 2,
-				"priority": 3,
-				"context": {
-					"val": "abc"
-				}
-			}`,
-			qc: queryContextEntry{1, 2, 3, map[string]interface{}{"val": "abc"}},
-		},
-		{
-			json: `{
-				"id": 1,
-				"timestamp": 2,
-				"priority": 3,
-				"context": [
-					"abc"
-				]
-			}`,
-			qc: queryContextEntry{1, 2, 3, []any{"abc"}},
-		},
-		{
-			json: `{
-				"id": 1,
-				"timestamp": 2,
-				"priority": 3,
-				"context": [
-					{
-						"val": "abc"
-					}
-				]
-			}`,
-			qc: queryContextEntry{1, 2, 3, []any{map[string]interface{}{"val": "abc"}}},
-		},
-	}
-
-	for _, tc := range testcases {
-		t.Run(trimWhitespaces(tc.json), func(t *testing.T) {
-			var qc queryContextEntry
-
-			err := json.NewDecoder(strings.NewReader(tc.json)).Decode(&qc)
-			if err != nil {
-				t.Fatalf("failed to decode json. %v", err)
-			}
-
-			if !reflect.DeepEqual(tc.qc, qc) {
-				t.Errorf("failed to decode json. expected: %v, got: %v", tc.qc, qc)
-			}
-
-			bytes, err := json.Marshal(qc)
-			if err != nil {
-				t.Fatalf("failed to encode json. %v", err)
-			}
-
-			resultJSON := string(bytes)
-			if resultJSON != trimWhitespaces(tc.json) {
-				t.Errorf("failed to encode json. epxected: %v, got: %v", trimWhitespaces(tc.json), resultJSON)
-			}
-		})
-	}
-}
-
-func trimWhitespaces(s string) string {
-	return strings.ReplaceAll(
-		strings.ReplaceAll(
-			strings.ReplaceAll(s, "\t", ""),
-			" ", ""),
-		"\n", "",
-	)
-}
-
 func TestSortingByPriority(t *testing.T) {
 	qcc := (&queryContextCache{}).init()
 	sc := htapTestSnowflakeConn()
@@ -257,14 +164,14 @@ func TestAddingQcesWithDifferentId(t *testing.T) {
 }
 
 func TestAddingQueryContextCacheEntry(t *testing.T) {
-	runSnowflakeConnTest(t, func(sc *snowflakeConn) {
+	runSnowflakeConnTest(t, func(sct *SCTest) {
 		t.Run("First query (may be on empty cache)", func(t *testing.T) {
-			entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries))
-			copy(entriesBefore, sc.queryContextCache.entries)
-			if _, err := sc.Query("SELECT 1", nil); err != nil {
+			entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries))
+			copy(entriesBefore, sct.sc.queryContextCache.entries)
+			if _, err := sct.sc.Query("SELECT 1", nil); err != nil {
 				t.Fatalf("cannot query. %v", err)
 			}
-			entriesAfter := sc.queryContextCache.entries
+			entriesAfter := sct.sc.queryContextCache.entries
 
 			if !containsNewEntries(entriesAfter, entriesBefore) {
 				t.Error("no new entries added to the query context cache")
@@ -272,15 +179,15 @@ func TestAddingQueryContextCacheEntry(t *testing.T) {
 		})
 
 		t.Run("Second query (cache should not be empty)", func(t *testing.T) {
-			entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries))
-			copy(entriesBefore, sc.queryContextCache.entries)
+			entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries))
+			copy(entriesBefore, sct.sc.queryContextCache.entries)
 			if len(entriesBefore) == 0 {
 				t.Fatalf("cache should not be empty after first query")
 			}
-			if _, err := sc.Query("SELECT 2", nil); err != nil {
+			if _, err := sct.sc.Query("SELECT 2", nil); err != nil {
 				t.Fatalf("cannot query. %v", err)
 			}
-			entriesAfter := sc.queryContextCache.entries
+			entriesAfter := sct.sc.queryContextCache.entries
 
 			if !containsNewEntries(entriesAfter, entriesBefore) {
 				t.Error("no new entries added to the query context cache")
@@ -306,9 +213,9 @@ func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryC
 }
 
 func TestPruneBySessionValue(t *testing.T) {
-	qce1 := queryContextEntry{1, 1, 1, nil}
-	qce2 := queryContextEntry{2, 2, 2, nil}
-	qce3 := queryContextEntry{3, 3, 3, nil}
+	qce1 := queryContextEntry{1, 1, 1, ""}
+	qce2 := queryContextEntry{2, 2, 2, ""}
+	qce3 := queryContextEntry{3, 3, 3, ""}
 
 	testcases := []struct {
 		size string
@@ -356,12 +263,12 @@ func TestPruneBySessionValue(t *testing.T) {
 }
 
 func TestPruneByDefaultValue(t *testing.T) {
-	qce1 := queryContextEntry{1, 1, 1, nil}
-	qce2 := queryContextEntry{2, 2, 2, nil}
-	qce3 := queryContextEntry{3, 3, 3, nil}
-	qce4 := queryContextEntry{4, 4, 4, nil}
-	qce5 := queryContextEntry{5, 5, 5, nil}
-	qce6 := queryContextEntry{6, 6, 6, nil}
+	qce1 := queryContextEntry{1, 1, 1, ""}
+	qce2 := queryContextEntry{2, 2, 2, ""}
+	qce3 := queryContextEntry{3, 3, 3, ""}
+	qce4 := queryContextEntry{4, 4, 4, ""}
+	qce5 := queryContextEntry{5, 5, 5, ""}
+	qce6 := queryContextEntry{6, 6, 6, ""}
 
 	sc := &snowflakeConn{
 		cfg: &Config{
@@ -387,7 +294,7 @@ func TestPruneByDefaultValue(t *testing.T) {
 }
 
 func TestNoQcesClearsCache(t *testing.T) {
-	qce1 := queryContextEntry{1, 1, 1, nil}
+	qce1 := queryContextEntry{1, 1, 1, ""}
 
 	sc := &snowflakeConn{
 		cfg: &Config{
@@ -416,3 +323,79 @@ func htapTestSnowflakeConn() *snowflakeConn {
 		},
 	}
 }
+
+func TestHybridTablesE2E(t *testing.T) {
+	if runningOnGithubAction() && !runningOnAWS() {
+		t.Skip("HTAP is enabled only on AWS")
+	}
+	runID := time.Now().UnixMilli()
+	testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID)
+	testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID)
+	runSnowflakeConnTest(t, func(sct *SCTest) {
+		dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil)
+		defer dbQuery.Close()
+		currentDb := make([]driver.Value, 1)
+		dbQuery.Next(currentDb)
+		defer func() {
+			sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil)
+			sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil)
+			sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil)
+		}()
+
+		t.Run("Run tests on first database", func(t *testing.T) {
+			sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil)
+			sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil)
+
+			sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil)
+			rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
+			defer rows.Close()
+			row := make([]driver.Value, 2)
+			rows.Next(row)
+			if row[0] != "1" || row[1] != "a" {
+				t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
+			}
+
+			sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil)
+			rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
+			defer rows2.Close()
+			rows2.Next(row)
+			if row[0] != "1" || row[1] != "a" {
+				t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
+			}
+			rows2.Next(row)
+			if row[0] != "2" || row[1] != "b" {
+				t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1])
+			}
+			if len(sct.sc.queryContextCache.entries) != 2 {
+				t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
+			}
+		})
+		t.Run("Run tests on second database", func(t *testing.T) {
+			sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil)
+			sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil)
+			sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil)
+
+			rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil)
+			defer rows.Close()
+			row := make([]driver.Value, 2)
+			rows.Next(row)
+			if row[0] != "3" || row[1] != "c" {
+				t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1])
+			}
+			if len(sct.sc.queryContextCache.entries) != 3 {
+				t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
+			}
+		})
+		t.Run("Run tests on first database again", func(t *testing.T) {
+			sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil)
+
+			sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil)
+
+			rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
+			defer rows.Close()
+			if len(sct.sc.queryContextCache.entries) != 3 {
+				t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
+			}
+		})
+	})
+}
diff --git a/query.go b/query.go
index 5d7dff053..edafdf990 100644
--- a/query.go
+++ b/query.go
@@ -27,6 +27,22 @@ type execRequest struct {
 	Parameters   map[string]interface{}       `json:"parameters,omitempty"`
 	Bindings     map[string]execBindParameter `json:"bindings,omitempty"`
 	BindStage    string                       `json:"bindStage,omitempty"`
+	QueryContext requestQueryContext          `json:"queryContextDTO,omitempty"`
+}
+
+type requestQueryContext struct {
+	Entries []requestQueryContextEntry `json:"entries,omitempty"`
+}
+
+type requestQueryContextEntry struct {
+	Context   contextData `json:"context,omitempty"`
+	ID        int         `json:"id"`
+	Priority  int         `json:"priority"`
+	Timestamp int64       `json:"timestamp,omitempty"`
+}
+
+type contextData struct {
+	Base64Data string `json:"base64Data,omitempty"`
 }
 
 type execResponseRowType struct {
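For reviewers who want to see the shape of the new queryContextDTO payload without running the driver, here is a minimal, self-contained sketch. The types mirror the ones added in query.go and htap.go, but the package name, the simplified buildQueryContext signature, and the sample base64 values are illustrative assumptions rather than the driver's exported API:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Standalone mirrors of the request-side types added in query.go; the real
// types live unexported inside package gosnowflake.
type contextData struct {
	Base64Data string `json:"base64Data,omitempty"`
}

type requestQueryContextEntry struct {
	Context   contextData `json:"context,omitempty"`
	ID        int         `json:"id"`
	Priority  int         `json:"priority"`
	Timestamp int64       `json:"timestamp,omitempty"`
}

type requestQueryContext struct {
	Entries []requestQueryContextEntry `json:"entries,omitempty"`
}

// A cached entry as kept after this change: Context holds the server-provided
// query context as an opaque base64 string.
type queryContextEntry struct {
	ID        int
	Timestamp int64
	Priority  int
	Context   string
}

// buildQueryContext sketches the helper added in connection.go: cached entries
// are copied into the request DTO, and only non-empty contexts are forwarded.
func buildQueryContext(entries []queryContextEntry) requestQueryContext {
	rqc := requestQueryContext{}
	for _, qce := range entries {
		cd := contextData{}
		if qce.Context != "" {
			cd.Base64Data = qce.Context
		}
		rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
			ID:        qce.ID,
			Priority:  qce.Priority,
			Timestamp: qce.Timestamp,
			Context:   cd,
		})
	}
	return rqc
}

func main() {
	// Hypothetical cache contents; the base64 value is a made-up placeholder.
	cache := []queryContextEntry{
		{ID: 0, Timestamp: 1, Priority: 0, Context: "c2Vzc2lvbi1zdGF0ZQ=="},
		{ID: 1, Timestamp: 2, Priority: 1, Context: ""},
	}
	// Marshal the DTO the same way the execRequest body would be serialized.
	payload, _ := json.Marshal(buildQueryContext(cache))
	fmt.Println(string(payload))
}
```

Running it prints a JSON object with an "entries" array, which is the structure the driver now attaches to execRequest.QueryContext so the server can carry hybrid-table query context from one statement to the next.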