Skip to content

Commit

Permalink
SNOW-857631 Handle multistatement query type
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Aug 11, 2023
1 parent 652242e commit e39de92
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
3 changes: 2 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ const (
)

const (
statementTypeIDMulti = int64(0x1000)
statementTypeIDSelect = int64(0x1000)
statementTypeIDDml = int64(0x3000)
statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500)
statementTypeIDMultistatement = int64(0xA000)
)

const (
Expand Down
4 changes: 2 additions & 2 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ func updateRows(data execResponseData) (int64, error) {
// Note that the statement type code is also equivalent to type INSERT, so an
// additional check of the name is required
func isMultiStmt(data *execResponseData) bool {
return data.StatementTypeID == statementTypeIDMulti &&
data.RowType[0].Name == "multiple statement execution"
var isMultistatementByReturningSelect = data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name == "multiple statement execution"
return isMultistatementByReturningSelect || data.StatementTypeID == statementTypeIDMultistatement
}

func getResumeQueryID(ctx context.Context) (string, error) {
Expand Down
34 changes: 22 additions & 12 deletions multistatement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
Expand All @@ -22,11 +23,8 @@ func TestMultiStatementExecuteNoResultSet(t *testing.T) {
"insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" +
"commit;"

runDBTest(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn")
dbt.mustExec(`create or replace table test_multi_statement_txn(
c1 number, c2 string) as select 10, 'z'`)
defer dbt.mustExec("drop table if exists test_multi_statement_txn")
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec(`create or replace table test_multi_statement_txn(c1 number, c2 string) as select 10, 'z'`)

res := dbt.mustExecContext(ctx, multiStmtQuery)
count, err := res.RowsAffected()
Expand All @@ -48,7 +46,8 @@ func TestMultiStatementQueryResultSet(t *testing.T) {

var v1, v2, v3 int64
var v4 string
runDBTest(t, func(dbt *DBTest) {

testForAllMultistatementTypes(t, func(dbt *DBTest) {
rows := dbt.mustQueryContext(ctx, multiStmtQuery)
defer rows.Close()

Expand Down Expand Up @@ -120,7 +119,7 @@ func TestMultiStatementExecuteResultSet(t *testing.T) {
"select 2;\n" +
"rollback;"

runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn_rb")
dbt.mustExec(`create or replace table test_multi_statement_txn_rb(
c1 number, c2 string) as select 10, 'z'`)
Expand All @@ -144,7 +143,7 @@ func TestMultiStatementQueryNoResultSet(t *testing.T) {
"insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" +
"commit;"

runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn")
dbt.mustExec(`create or replace table test_multi_statement_txn(
c1 number, c2 string) as select 10, 'z'`)
Expand All @@ -161,7 +160,7 @@ func TestMultiStatementExecuteMix(t *testing.T) {
"insert into test_multi values (1), (2);\n" +
"select cola from test_multi order by cola asc;"

runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn")
dbt.mustExec(`create or replace table test_multi_statement_txn(
c1 number, c2 string) as select 10, 'z'`)
Expand All @@ -185,7 +184,7 @@ func TestMultiStatementQueryMix(t *testing.T) {
"select cola from test_multi order by cola asc;"

var count, v int
runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("drop table if exists test_multi_statement_txn")
dbt.mustExec(`create or replace table test_multi_statement_txn(
c1 number, c2 string) as select 10, 'z'`)
Expand Down Expand Up @@ -232,7 +231,7 @@ func TestMultiStatementCountZero(t *testing.T) {
var v3 float64
var v4 bool

runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
// first query
multiStmtQuery1 := "select 123;\n" +
"select '456';"
Expand Down Expand Up @@ -352,7 +351,7 @@ func TestMultiStatementVaryingColumnCount(t *testing.T) {
ctx, _ := WithMultiStatement(context.Background(), 0)

var v1, v2 int
runDBTest(t, func(dbt *DBTest) {
testForAllMultistatementTypes(t, func(dbt *DBTest) {
dbt.mustExec("create or replace table test_tbl(c1 int, c2 int)")
dbt.mustExec("insert into test_tbl values(1, 0)")
defer dbt.mustExec("drop table if exists test_tbl")
Expand Down Expand Up @@ -591,3 +590,14 @@ func TestUnitHandleMultiQuery(t *testing.T) {
t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number)
}
}

func testForAllMultistatementTypes(t *testing.T, test func(dbt *DBTest)) {
for _, enableMultistatementType := range []bool{false, true} {
t.Run(fmt.Sprintf("enableMultistatementType=%v", enableMultistatementType), func(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
dbt.mustExec(fmt.Sprintf("ALTER SESSION SET ENABLE_MULTI_STMT_QUERY_TYPE = %v", enableMultistatementType))
test(dbt)
})
})
}
}

0 comments on commit e39de92

Please sign in to comment.