diff --git a/connection.go b/connection.go index a48af3600..326a269d4 100644 --- a/connection.go +++ b/connection.go @@ -11,6 +11,7 @@ import ( "database/sql/driver" "encoding/base64" "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -85,6 +86,10 @@ func (sc *snowflakeConn) exec( var err error counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter + queryContext, err := buildQueryContext(sc.queryContextCache) + if err != nil { + logger.Warnf("error while building query context: %v", err) + } req := execRequest{ SQLText: query, AsyncExec: noResult, @@ -92,6 +97,7 @@ func (sc *snowflakeConn) exec( IsInternal: isInternal, DescribeOnly: describeOnly, SequenceID: counter, + QueryContext: queryContext, } if key := ctx.Value(multiStatementCount); key != nil { req.Parameters[string(multiStatementCount)] = key @@ -160,6 +166,45 @@ func (sc *snowflakeConn) exec( return data, err } +func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) { + rqc := requestQueryContext{} + if qcc == nil || len(qcc.entries) == 0 { + logger.Debugf("empty qcc") + return rqc, nil + } + for _, qce := range qcc.entries { + contextBase64, err := marshallContext(qce.Context) + if err != nil { + return rqc, err + } + rqc.Entries = append(rqc.Entries, requestQueryContextEntry{ + ID: qce.ID, + Priority: qce.Priority, + Timestamp: qce.Timestamp, + Context: contextData{ + Base64Data: contextBase64, + }, + }) + } + return rqc, nil +} + +func marshallContext(context any) (string, error) { + if context == nil { + return "", nil + } + contextJSON, err := json.Marshal(context) + if err != nil { + return "", fmt.Errorf("cannot serialize query context to JSON: %v, %v", err, context) + } + var contextBase64 []byte + _, err = base64.StdEncoding.Decode(contextBase64, contextJSON) + if err != nil { + return "", fmt.Errorf("cannot decode to base64: %v", err) + } + return string(contextBase64), nil +} + func (sc *snowflakeConn) Begin() (driver.Tx, error) { return sc.BeginTx(sc.ctx, driver.TxOptions{}) } diff --git a/driver_test.go b/driver_test.go index 653c62777..9d471d784 100644 --- a/driver_test.go +++ b/driver_test.go @@ -6,6 +6,7 @@ import ( "context" "crypto/rsa" "database/sql" + "database/sql/driver" "flag" "fmt" "net/http" @@ -320,6 +321,34 @@ func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) { return stmt } +type SCTest struct { + *testing.T + sc *snowflakeConn +} + +func (sct *SCTest) fail(method, query string, err error) { + if len(query) > 300 { + query = "[query too large to print]" + } + sct.Fatalf("error on %s [%s]: %s", method, query, err.Error()) +} + +func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result { + result, err := sct.sc.Exec(query, args) + if err != nil { + sct.fail("exec", query, err) + } + return result +} + +func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows { + rows, err := sct.sc.Query(query, args) + if err != nil { + sct.fail("query", query, err) + } + return rows +} + func runDBTest(t *testing.T, test func(dbt *DBTest)) { conn := openConn(t) defer conn.Close() @@ -328,7 +357,7 @@ func runDBTest(t *testing.T, test func(dbt *DBTest)) { test(dbt) } -func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) { +func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) { config, err := ParseDSN(dsn) if err != nil { t.Error(err) @@ -342,7 +371,9 @@ func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) { t.Fatal(err) } - test(sc) + sct := &SCTest{t, sc} + + test(sct) } func runningOnAWS() bool { diff --git a/htap_test.go b/htap_test.go index e2e71eb20..f00990a22 100644 --- a/htap_test.go +++ b/htap_test.go @@ -1,10 +1,13 @@ package gosnowflake import ( + "database/sql/driver" "encoding/json" + "fmt" "reflect" "strings" "testing" + "time" ) func TestMarshallAndDecodeOpaqueContext(t *testing.T) { @@ -257,14 +260,14 @@ func TestAddingQcesWithDifferentId(t *testing.T) { } func TestAddingQueryContextCacheEntry(t *testing.T) { - runSnowflakeConnTest(t, func(sc *snowflakeConn) { + runSnowflakeConnTest(t, func(sct *SCTest) { t.Run("First query (may be on empty cache)", func(t *testing.T) { - entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries)) - copy(entriesBefore, sc.queryContextCache.entries) - if _, err := sc.Query("SELECT 1", nil); err != nil { + entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries)) + copy(entriesBefore, sct.sc.queryContextCache.entries) + if _, err := sct.sc.Query("SELECT 1", nil); err != nil { t.Fatalf("cannot query. %v", err) } - entriesAfter := sc.queryContextCache.entries + entriesAfter := sct.sc.queryContextCache.entries if !containsNewEntries(entriesAfter, entriesBefore) { t.Error("no new entries added to the query context cache") @@ -272,15 +275,15 @@ func TestAddingQueryContextCacheEntry(t *testing.T) { }) t.Run("Second query (cache should not be empty)", func(t *testing.T) { - entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries)) - copy(entriesBefore, sc.queryContextCache.entries) + entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries)) + copy(entriesBefore, sct.sc.queryContextCache.entries) if len(entriesBefore) == 0 { t.Fatalf("cache should not be empty after first query") } - if _, err := sc.Query("SELECT 2", nil); err != nil { + if _, err := sct.sc.Query("SELECT 2", nil); err != nil { t.Fatalf("cannot query. %v", err) } - entriesAfter := sc.queryContextCache.entries + entriesAfter := sct.sc.queryContextCache.entries if !containsNewEntries(entriesAfter, entriesBefore) { t.Error("no new entries added to the query context cache") @@ -416,3 +419,79 @@ func htapTestSnowflakeConn() *snowflakeConn { }, } } + +func TestHybridTablesE2E(t *testing.T) { + if runningOnGithubAction() && !runningOnAWS() { + t.Skip("HTAP is enabled only on AWS") + } + runID := time.Now().UnixMilli() + testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID) + testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID) + runSnowflakeConnTest(t, func(sct *SCTest) { + dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil) + defer dbQuery.Close() + currentDb := make([]driver.Value, 1) + dbQuery.Next(currentDb) + defer func() { + sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil) + sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil) + sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil) + }() + + t.Run("Run tests on first database", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil) + sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil) + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil) + rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows.Close() + row := make([]driver.Value, 2) + rows.Next(row) + if row[0] != "1" || row[1] != "a" { + t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1]) + } + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil) + rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows2.Close() + rows2.Next(row) + if row[0] != "1" || row[1] != "a" { + t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1]) + } + rows2.Next(row) + if row[0] != "2" || row[1] != "b" { + t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1]) + } + if len(sct.sc.queryContextCache.entries) != 2 { + t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + t.Run("Run tests on second database", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil) + sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil) + sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil) + + rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil) + defer rows.Close() + row := make([]driver.Value, 2) + rows.Next(row) + if row[0] != "3" || row[1] != "c" { + t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1]) + } + if len(sct.sc.queryContextCache.entries) != 3 { + t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + t.Run("Run tests on first database again", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil) + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil) + + rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows.Close() + if len(sct.sc.queryContextCache.entries) != 3 { + t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + }) +} diff --git a/query.go b/query.go index 5d7dff053..edafdf990 100644 --- a/query.go +++ b/query.go @@ -27,6 +27,22 @@ type execRequest struct { Parameters map[string]interface{} `json:"parameters,omitempty"` Bindings map[string]execBindParameter `json:"bindings,omitempty"` BindStage string `json:"bindStage,omitempty"` + QueryContext requestQueryContext `json:"queryContextDTO,omitempty"` +} + +type requestQueryContext struct { + Entries []requestQueryContextEntry `json:"entries,omitempty"` +} + +type requestQueryContextEntry struct { + Context contextData `json:"context,omitempty"` + ID int `json:"id"` + Priority int `json:"priority"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +type contextData struct { + Base64Data string `json:"base64Data,omitempty"` } type execResponseRowType struct {