Skip to content

Commit

Permalink
SNOW-895537: Send query context with request
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Sep 4, 2023
1 parent eecf8bf commit 74affb8
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 11 deletions.
45 changes: 45 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"database/sql/driver"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -85,13 +86,18 @@ func (sc *snowflakeConn) exec(
var err error
counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter

queryContext, err := buildQueryContext(sc.queryContextCache)
if err != nil {
logger.Warnf("error while building query context: %v", err)
}
req := execRequest{
SQLText: query,
AsyncExec: noResult,
Parameters: map[string]interface{}{},
IsInternal: isInternal,
DescribeOnly: describeOnly,
SequenceID: counter,
QueryContext: queryContext,
}
if key := ctx.Value(multiStatementCount); key != nil {
req.Parameters[string(multiStatementCount)] = key
Expand Down Expand Up @@ -160,6 +166,45 @@ func (sc *snowflakeConn) exec(
return data, err
}

func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) {
rqc := requestQueryContext{}
if qcc == nil || len(qcc.entries) == 0 {
logger.Debugf("empty qcc")
return rqc, nil
}
for _, qce := range qcc.entries {
contextBase64, err := marshallContext(qce.Context)
if err != nil {
return rqc, err
}
rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
ID: qce.ID,
Priority: qce.Priority,
Timestamp: qce.Timestamp,
Context: contextData{
Base64Data: contextBase64,
},
})
}
return rqc, nil
}

func marshallContext(context any) (string, error) {
if context == nil {
return "", nil
}
contextJSON, err := json.Marshal(context)
if err != nil {
return "", fmt.Errorf("cannot serialize query context to JSON: %v, %v", err, context)
}
var contextBase64 []byte
_, err = base64.StdEncoding.Decode(contextBase64, contextJSON)
if err != nil {
return "", fmt.Errorf("cannot decode to base64: %v", err)
}
return string(contextBase64), nil
}

func (sc *snowflakeConn) Begin() (driver.Tx, error) {
return sc.BeginTx(sc.ctx, driver.TxOptions{})
}
Expand Down
35 changes: 33 additions & 2 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"crypto/rsa"
"database/sql"
"database/sql/driver"
"flag"
"fmt"
"net/http"
Expand Down Expand Up @@ -320,6 +321,34 @@ func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) {
return stmt
}

type SCTest struct {
*testing.T
sc *snowflakeConn
}

func (sct *SCTest) fail(method, query string, err error) {
if len(query) > 300 {
query = "[query too large to print]"
}
sct.Fatalf("error on %s [%s]: %s", method, query, err.Error())
}

func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result {
result, err := sct.sc.Exec(query, args)
if err != nil {
sct.fail("exec", query, err)
}
return result
}

func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows {
rows, err := sct.sc.Query(query, args)
if err != nil {
sct.fail("query", query, err)
}
return rows
}

func runDBTest(t *testing.T, test func(dbt *DBTest)) {
conn := openConn(t)
defer conn.Close()
Expand All @@ -328,7 +357,7 @@ func runDBTest(t *testing.T, test func(dbt *DBTest)) {
test(dbt)
}

func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) {
func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) {
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
Expand All @@ -342,7 +371,9 @@ func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) {
t.Fatal(err)
}

test(sc)
sct := &SCTest{t, sc}

test(sct)
}

func runningOnAWS() bool {
Expand Down
97 changes: 88 additions & 9 deletions htap_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package gosnowflake

import (
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"strings"
"testing"
"time"
)

func TestMarshallAndDecodeOpaqueContext(t *testing.T) {
Expand Down Expand Up @@ -257,30 +260,30 @@ func TestAddingQcesWithDifferentId(t *testing.T) {
}

func TestAddingQueryContextCacheEntry(t *testing.T) {
runSnowflakeConnTest(t, func(sc *snowflakeConn) {
runSnowflakeConnTest(t, func(sct *SCTest) {
t.Run("First query (may be on empty cache)", func(t *testing.T) {
entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries))
copy(entriesBefore, sc.queryContextCache.entries)
if _, err := sc.Query("SELECT 1", nil); err != nil {
entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries))
copy(entriesBefore, sct.sc.queryContextCache.entries)
if _, err := sct.sc.Query("SELECT 1", nil); err != nil {
t.Fatalf("cannot query. %v", err)
}
entriesAfter := sc.queryContextCache.entries
entriesAfter := sct.sc.queryContextCache.entries

if !containsNewEntries(entriesAfter, entriesBefore) {
t.Error("no new entries added to the query context cache")
}
})

t.Run("Second query (cache should not be empty)", func(t *testing.T) {
entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries))
copy(entriesBefore, sc.queryContextCache.entries)
entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries))
copy(entriesBefore, sct.sc.queryContextCache.entries)
if len(entriesBefore) == 0 {
t.Fatalf("cache should not be empty after first query")
}
if _, err := sc.Query("SELECT 2", nil); err != nil {
if _, err := sct.sc.Query("SELECT 2", nil); err != nil {
t.Fatalf("cannot query. %v", err)
}
entriesAfter := sc.queryContextCache.entries
entriesAfter := sct.sc.queryContextCache.entries

if !containsNewEntries(entriesAfter, entriesBefore) {
t.Error("no new entries added to the query context cache")
Expand Down Expand Up @@ -416,3 +419,79 @@ func htapTestSnowflakeConn() *snowflakeConn {
},
}
}

func TestHybridTablesE2E(t *testing.T) {
if runningOnGithubAction() && !runningOnAWS() {
t.Skip("HTAP is enabled only on AWS")
}
runID := time.Now().UnixMilli()
testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID)
testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID)
runSnowflakeConnTest(t, func(sct *SCTest) {
dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil)
defer dbQuery.Close()
currentDb := make([]driver.Value, 1)
dbQuery.Next(currentDb)
defer func() {
sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil)
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil)
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil)
}()

t.Run("Run tests on first database", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil)
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil)

sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil)
rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows.Close()
row := make([]driver.Value, 2)
rows.Next(row)
if row[0] != "1" || row[1] != "a" {
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
}

sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil)
rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows2.Close()
rows2.Next(row)
if row[0] != "1" || row[1] != "a" {
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
}
rows2.Next(row)
if row[0] != "2" || row[1] != "b" {
t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1])
}
if len(sct.sc.queryContextCache.entries) != 2 {
t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
t.Run("Run tests on second database", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil)
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil)
sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil)

rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil)
defer rows.Close()
row := make([]driver.Value, 2)
rows.Next(row)
if row[0] != "3" || row[1] != "c" {
t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1])
}
if len(sct.sc.queryContextCache.entries) != 3 {
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
t.Run("Run tests on first database again", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil)

sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil)

rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows.Close()
if len(sct.sc.queryContextCache.entries) != 3 {
t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
})
}
16 changes: 16 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ type execRequest struct {
Parameters map[string]interface{} `json:"parameters,omitempty"`
Bindings map[string]execBindParameter `json:"bindings,omitempty"`
BindStage string `json:"bindStage,omitempty"`
QueryContext requestQueryContext `json:"queryContextDTO,omitempty"`
}

type requestQueryContext struct {
Entries []requestQueryContextEntry `json:"entries,omitempty"`
}

type requestQueryContextEntry struct {
Context contextData `json:"context,omitempty"`
ID int `json:"id"`
Priority int `json:"priority"`
Timestamp int64 `json:"timestamp,omitempty"`
}

type contextData struct {
Base64Data string `json:"base64Data,omitempty"`
}

type execResponseRowType struct {
Expand Down

0 comments on commit 74affb8

Please sign in to comment.