Skip to content

Commit

Permalink
SNOW-895537: Send query context with request
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Sep 7, 2023
1 parent eecf8bf commit 640dfc1
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 121 deletions.
26 changes: 26 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,18 @@ func (sc *snowflakeConn) exec(
var err error
counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter

queryContext, err := buildQueryContext(sc.queryContextCache)
if err != nil {
logger.Errorf("error while building query context: %v", err)
}
req := execRequest{
SQLText: query,
AsyncExec: noResult,
Parameters: map[string]interface{}{},
IsInternal: isInternal,
DescribeOnly: describeOnly,
SequenceID: counter,
QueryContext: queryContext,
}
if key := ctx.Value(multiStatementCount); key != nil {
req.Parameters[string(multiStatementCount)] = key
Expand Down Expand Up @@ -160,6 +165,27 @@ func (sc *snowflakeConn) exec(
return data, err
}

func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) {
rqc := requestQueryContext{}
if qcc == nil || len(qcc.entries) == 0 {
logger.Debugf("empty qcc")
return rqc, nil
}
for _, qce := range qcc.entries {
contextData := contextData{}
if qce.Context == "" {
contextData.Base64Data = qce.Context
}
rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
ID: qce.ID,
Priority: qce.Priority,
Timestamp: qce.Timestamp,
Context: contextData,
})
}
return rqc, nil
}

func (sc *snowflakeConn) Begin() (driver.Tx, error) {
return sc.BeginTx(sc.ctx, driver.TxOptions{})
}
Expand Down
35 changes: 33 additions & 2 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"crypto/rsa"
"database/sql"
"database/sql/driver"
"flag"
"fmt"
"net/http"
Expand Down Expand Up @@ -320,6 +321,34 @@ func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) {
return stmt
}

type SCTest struct {
*testing.T
sc *snowflakeConn
}

func (sct *SCTest) fail(method, query string, err error) {
if len(query) > 300 {
query = "[query too large to print]"
}
sct.Fatalf("error on %s [%s]: %s", method, query, err.Error())
}

func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result {
result, err := sct.sc.Exec(query, args)
if err != nil {
sct.fail("exec", query, err)
}
return result
}

func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows {
rows, err := sct.sc.Query(query, args)
if err != nil {
sct.fail("query", query, err)
}
return rows
}

func runDBTest(t *testing.T, test func(dbt *DBTest)) {
conn := openConn(t)
defer conn.Close()
Expand All @@ -328,7 +357,7 @@ func runDBTest(t *testing.T, test func(dbt *DBTest)) {
test(dbt)
}

func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) {
func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) {
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
Expand All @@ -342,7 +371,9 @@ func runSnowflakeConnTest(t *testing.T, test func(sc *snowflakeConn)) {
t.Fatal(err)
}

test(sc)
sct := &SCTest{t, sc}

test(sct)
}

func runningOnAWS() bool {
Expand Down
8 changes: 4 additions & 4 deletions htap.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ const (
)

type queryContextEntry struct {
ID int `json:"id"`
Timestamp int64 `json:"timestamp"`
Priority int `json:"priority"`
Context any `json:"context,omitempty"`
ID int `json:"id"`
Timestamp int64 `json:"timestamp"`
Priority int `json:"priority"`
Context string `json:"context,omitempty"`
}

type queryContextCache struct {
Expand Down
213 changes: 98 additions & 115 deletions htap_test.go
Original file line number Diff line number Diff line change
@@ -1,106 +1,13 @@
package gosnowflake

import (
"encoding/json"
"database/sql/driver"
"fmt"
"reflect"
"strings"
"testing"
"time"
)

func TestMarshallAndDecodeOpaqueContext(t *testing.T) {
testcases := []struct {
json string
qc queryContextEntry
}{
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3
}`,
qc: queryContextEntry{1, 2, 3, nil},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": "abc"
}`,
qc: queryContextEntry{1, 2, 3, "abc"},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": {
"val": "abc"
}
}`,
qc: queryContextEntry{1, 2, 3, map[string]interface{}{"val": "abc"}},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": [
"abc"
]
}`,
qc: queryContextEntry{1, 2, 3, []any{"abc"}},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": [
{
"val": "abc"
}
]
}`,
qc: queryContextEntry{1, 2, 3, []any{map[string]interface{}{"val": "abc"}}},
},
}

for _, tc := range testcases {
t.Run(trimWhitespaces(tc.json), func(t *testing.T) {
var qc queryContextEntry

err := json.NewDecoder(strings.NewReader(tc.json)).Decode(&qc)
if err != nil {
t.Fatalf("failed to decode json. %v", err)
}

if !reflect.DeepEqual(tc.qc, qc) {
t.Errorf("failed to decode json. expected: %v, got: %v", tc.qc, qc)
}

bytes, err := json.Marshal(qc)
if err != nil {
t.Fatalf("failed to encode json. %v", err)
}

resultJSON := string(bytes)
if resultJSON != trimWhitespaces(tc.json) {
t.Errorf("failed to encode json. epxected: %v, got: %v", trimWhitespaces(tc.json), resultJSON)
}
})
}
}

func trimWhitespaces(s string) string {
return strings.ReplaceAll(
strings.ReplaceAll(
strings.ReplaceAll(s, "\t", ""),
" ", ""),
"\n", "",
)
}

func TestSortingByPriority(t *testing.T) {
qcc := (&queryContextCache{}).init()
sc := htapTestSnowflakeConn()
Expand Down Expand Up @@ -257,30 +164,30 @@ func TestAddingQcesWithDifferentId(t *testing.T) {
}

func TestAddingQueryContextCacheEntry(t *testing.T) {
runSnowflakeConnTest(t, func(sc *snowflakeConn) {
runSnowflakeConnTest(t, func(sct *SCTest) {
t.Run("First query (may be on empty cache)", func(t *testing.T) {
entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries))
copy(entriesBefore, sc.queryContextCache.entries)
if _, err := sc.Query("SELECT 1", nil); err != nil {
entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries))
copy(entriesBefore, sct.sc.queryContextCache.entries)
if _, err := sct.sc.Query("SELECT 1", nil); err != nil {
t.Fatalf("cannot query. %v", err)
}
entriesAfter := sc.queryContextCache.entries
entriesAfter := sct.sc.queryContextCache.entries

if !containsNewEntries(entriesAfter, entriesBefore) {
t.Error("no new entries added to the query context cache")
}
})

t.Run("Second query (cache should not be empty)", func(t *testing.T) {
entriesBefore := make([]queryContextEntry, len(sc.queryContextCache.entries))
copy(entriesBefore, sc.queryContextCache.entries)
entriesBefore := make([]queryContextEntry, len(sct.sc.queryContextCache.entries))
copy(entriesBefore, sct.sc.queryContextCache.entries)
if len(entriesBefore) == 0 {
t.Fatalf("cache should not be empty after first query")
}
if _, err := sc.Query("SELECT 2", nil); err != nil {
if _, err := sct.sc.Query("SELECT 2", nil); err != nil {
t.Fatalf("cannot query. %v", err)
}
entriesAfter := sc.queryContextCache.entries
entriesAfter := sct.sc.queryContextCache.entries

if !containsNewEntries(entriesAfter, entriesBefore) {
t.Error("no new entries added to the query context cache")
Expand All @@ -306,9 +213,9 @@ func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryC
}

func TestPruneBySessionValue(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce2 := queryContextEntry{2, 2, 2, nil}
qce3 := queryContextEntry{3, 3, 3, nil}
qce1 := queryContextEntry{1, 1, 1, ""}
qce2 := queryContextEntry{2, 2, 2, ""}
qce3 := queryContextEntry{3, 3, 3, ""}

testcases := []struct {
size string
Expand Down Expand Up @@ -356,12 +263,12 @@ func TestPruneBySessionValue(t *testing.T) {
}

func TestPruneByDefaultValue(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce2 := queryContextEntry{2, 2, 2, nil}
qce3 := queryContextEntry{3, 3, 3, nil}
qce4 := queryContextEntry{4, 4, 4, nil}
qce5 := queryContextEntry{5, 5, 5, nil}
qce6 := queryContextEntry{6, 6, 6, nil}
qce1 := queryContextEntry{1, 1, 1, ""}
qce2 := queryContextEntry{2, 2, 2, ""}
qce3 := queryContextEntry{3, 3, 3, ""}
qce4 := queryContextEntry{4, 4, 4, ""}
qce5 := queryContextEntry{5, 5, 5, ""}
qce6 := queryContextEntry{6, 6, 6, ""}

sc := &snowflakeConn{
cfg: &Config{
Expand All @@ -387,7 +294,7 @@ func TestPruneByDefaultValue(t *testing.T) {
}

func TestNoQcesClearsCache(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce1 := queryContextEntry{1, 1, 1, ""}

sc := &snowflakeConn{
cfg: &Config{
Expand Down Expand Up @@ -416,3 +323,79 @@ func htapTestSnowflakeConn() *snowflakeConn {
},
}
}

func TestHybridTablesE2E(t *testing.T) {
if runningOnGithubAction() && !runningOnAWS() {
t.Skip("HTAP is enabled only on AWS")
}
runID := time.Now().UnixMilli()
testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID)
testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID)
runSnowflakeConnTest(t, func(sct *SCTest) {
dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil)
defer dbQuery.Close()
currentDb := make([]driver.Value, 1)
dbQuery.Next(currentDb)
defer func() {
sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil)
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil)
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil)
}()

t.Run("Run tests on first database", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil)
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil)

sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil)
rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows.Close()
row := make([]driver.Value, 2)
rows.Next(row)
if row[0] != "1" || row[1] != "a" {
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
}

sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil)
rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows2.Close()
rows2.Next(row)
if row[0] != "1" || row[1] != "a" {
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
}
rows2.Next(row)
if row[0] != "2" || row[1] != "b" {
t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1])
}
if len(sct.sc.queryContextCache.entries) != 2 {
t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
t.Run("Run tests on second database", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil)
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil)
sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil)

rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil)
defer rows.Close()
row := make([]driver.Value, 2)
rows.Next(row)
if row[0] != "3" || row[1] != "c" {
t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1])
}
if len(sct.sc.queryContextCache.entries) != 3 {
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
t.Run("Run tests on first database again", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil)

sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil)

rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows.Close()
if len(sct.sc.queryContextCache.entries) != 3 {
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
})
}
Loading

0 comments on commit 640dfc1

Please sign in to comment.