From 2081db728127f8917df99637e5fe67a624eea9e0 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Mon, 24 Jul 2023 14:17:32 +0200 Subject: [PATCH] SNOW-857631 Handle multistatement query type --- arrow_test.go | 21 +++++++++++++++++++++ connection.go | 3 ++- connection_util.go | 4 ++-- multistatement_test.go | 34 ++++++++++++++++++++++------------ 4 files changed, 47 insertions(+), 15 deletions(-) diff --git a/arrow_test.go b/arrow_test.go index 1234a6765..cca07603d 100644 --- a/arrow_test.go +++ b/arrow_test.go @@ -13,6 +13,27 @@ import ( "time" ) +//A test just to show Snowflake version +func TestCheckVersion(t *testing.T) { + conn := openConn(t) + defer conn.Close() + + rows, err := conn.QueryContext(context.Background(), "SELECT current_version()") + if err != nil { + t.Error(err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatalf("failed to find any row") + } + var s string + if err = rows.Scan(&s); err != nil { + t.Fatal(err) + } + println(s) +} + func TestArrowBigInt(t *testing.T) { conn := openConn(t) defer conn.Close() diff --git a/connection.go b/connection.go index aa26e8532..2941ac9fe 100644 --- a/connection.go +++ b/connection.go @@ -37,9 +37,10 @@ const ( ) const ( - statementTypeIDMulti = int64(0x1000) + statementTypeIDSelect = int64(0x1000) statementTypeIDDml = int64(0x3000) statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500) + statementTypeIDMultistatement = int64(0xA000) ) const ( diff --git a/connection_util.go b/connection_util.go index 2a45e363c..ac1061f67 100644 --- a/connection_util.go +++ b/connection_util.go @@ -213,8 +213,8 @@ func updateRows(data execResponseData) (int64, error) { // Note that the statement type code is also equivalent to type INSERT, so an // additional check of the name is required func isMultiStmt(data *execResponseData) bool { - return data.StatementTypeID == statementTypeIDMulti && - data.RowType[0].Name == "multiple statement execution" + var isMultistatementByReturningSelect = data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name == "multiple statement execution" + return isMultistatementByReturningSelect || data.StatementTypeID == statementTypeIDMultistatement } func getResumeQueryID(ctx context.Context) (string, error) { diff --git a/multistatement_test.go b/multistatement_test.go index c7a5c5cae..f61185e37 100644 --- a/multistatement_test.go +++ b/multistatement_test.go @@ -6,6 +6,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "net/url" @@ -22,11 +23,8 @@ func TestMultiStatementExecuteNoResultSet(t *testing.T) { "insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" + "commit;" - runDBTest(t, func(dbt *DBTest) { - dbt.mustExec("drop table if exists test_multi_statement_txn") - dbt.mustExec(`create or replace table test_multi_statement_txn( - c1 number, c2 string) as select 10, 'z'`) - defer dbt.mustExec("drop table if exists test_multi_statement_txn") + testForAllMultistatementTypes(t, func(dbt *DBTest) { + dbt.mustExec(`create or replace table test_multi_statement_txn(c1 number, c2 string) as select 10, 'z'`) res := dbt.mustExecContext(ctx, multiStmtQuery) count, err := res.RowsAffected() @@ -48,7 +46,8 @@ func TestMultiStatementQueryResultSet(t *testing.T) { var v1, v2, v3 int64 var v4 string - runDBTest(t, func(dbt *DBTest) { + + testForAllMultistatementTypes(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, multiStmtQuery) defer rows.Close() @@ -120,7 +119,7 @@ func TestMultiStatementExecuteResultSet(t *testing.T) { "select 2;\n" + "rollback;" - runDBTest(t, func(dbt *DBTest) { + testForAllMultistatementTypes(t, func(dbt *DBTest) { dbt.mustExec("drop table if exists test_multi_statement_txn_rb") dbt.mustExec(`create or replace table test_multi_statement_txn_rb( c1 number, c2 string) as select 10, 'z'`) @@ -144,7 +143,7 @@ func TestMultiStatementQueryNoResultSet(t *testing.T) { "insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" + "commit;" - runDBTest(t, func(dbt *DBTest) { + testForAllMultistatementTypes(t, func(dbt *DBTest) { dbt.mustExec("drop table if exists test_multi_statement_txn") dbt.mustExec(`create or replace table test_multi_statement_txn( c1 number, c2 string) as select 10, 'z'`) @@ -161,7 +160,7 @@ func TestMultiStatementExecuteMix(t *testing.T) { "insert into test_multi values (1), (2);\n" + "select cola from test_multi order by cola asc;" - runDBTest(t, func(dbt *DBTest) { + testForAllMultistatementTypes(t, func(dbt *DBTest) { dbt.mustExec("drop table if exists test_multi_statement_txn") dbt.mustExec(`create or replace table test_multi_statement_txn( c1 number, c2 string) as select 10, 'z'`) @@ -185,7 +184,7 @@ func TestMultiStatementQueryMix(t *testing.T) { "select cola from test_multi order by cola asc;" var count, v int - runDBTest(t, func(dbt *DBTest) { + testForAllMultistatementTypes(t, func(dbt *DBTest) { dbt.mustExec("drop table if exists test_multi_statement_txn") dbt.mustExec(`create or replace table test_multi_statement_txn( c1 number, c2 string) as select 10, 'z'`) @@ -232,7 +231,7 @@ func TestMultiStatementCountZero(t *testing.T) { var v3 float64 var v4 bool - runDBTest(t, func(dbt *DBTest) { + testForAllMultistatementTypes(t, func(dbt *DBTest) { // first query multiStmtQuery1 := "select 123;\n" + "select '456';" @@ -352,7 +351,7 @@ func TestMultiStatementVaryingColumnCount(t *testing.T) { ctx, _ := WithMultiStatement(context.Background(), 0) var v1, v2 int - runDBTest(t, func(dbt *DBTest) { + testForAllMultistatementTypes(t, func(dbt *DBTest) { dbt.mustExec("create or replace table test_tbl(c1 int, c2 int)") dbt.mustExec("insert into test_tbl values(1, 0)") defer dbt.mustExec("drop table if exists test_tbl") @@ -593,3 +592,14 @@ func TestUnitHandleMultiQuery(t *testing.T) { t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) } } + +func testForAllMultistatementTypes(t *testing.T, test func(dbt *DBTest)) { + for _, enableMultistatementType := range []bool{false, true} { + t.Run(fmt.Sprintf("enableMultistatementType=%v", enableMultistatementType), func(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + dbt.mustExec(fmt.Sprintf("ALTER SESSION SET ENABLE_MULTI_STMT_QUERY_TYPE = %v", enableMultistatementType)) + test(dbt) + }) + }) + } +}