Skip to content

Commit

Permalink
SNOW-911146 Extract method of preparing connection in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pmotacki committed Sep 8, 2023
1 parent eecf8bf commit ba92972
Show file tree
Hide file tree
Showing 11 changed files with 1,238 additions and 1,570 deletions.
417 changes: 195 additions & 222 deletions chunk_test.go

Large diffs are not rendered by default.

303 changes: 114 additions & 189 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,70 +398,49 @@ func TestPrivateLink(t *testing.T) {
}

func TestGetQueryStatus(t *testing.T) {
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
}
ctx := context.Background()
sc, err := buildSnowflakeConn(ctx, *config)
if err != nil {
t.Error(err)
}
if err = authenticateWithConfig(sc); err != nil {
t.Error(err)
}
runSnowflakeConnTest(t, func(sct *SCTest) {
ctx := sct.sc.ctx

if _, err = sc.Exec(`create or replace table ut_conn(c1 number, c2 string)
sct.mustExec(`create or replace table ut_conn(c1 number, c2 string)
as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 from
table(generator(rowcount => 100)))`, nil); err != nil {
t.Error(err)
}
table(generator(rowcount => 100)))`, nil)

rows, err := sc.QueryContext(ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil)
if err != nil {
t.Error(err)
}
qid := rows.(SnowflakeResult).GetQueryID()
rows, err := sct.sc.QueryContext(ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil)
if err != nil {
t.Error(err)
}
qid := rows.(SnowflakeResult).GetQueryID()

// use conn as type holder for SnowflakeConnection placeholder
var conn interface{} = sc
qStatus, err := conn.(SnowflakeConnection).GetQueryStatus(ctx, qid)
if err != nil {
t.Errorf("failed to get query status err = %s", err.Error())
return
}
if qStatus == nil {
t.Error("there was no query status returned")
return
}
// use conn as type holder for SnowflakeConnection placeholder
var conn interface{} = sct.sc
qStatus, err := conn.(SnowflakeConnection).GetQueryStatus(ctx, qid)
if err != nil {
t.Errorf("failed to get query status err = %s", err.Error())
return
}
if qStatus == nil {
t.Error("there was no query status returned")
return
}

if qStatus.ErrorCode != "" || qStatus.ScanBytes != 2048 || qStatus.ProducedRows != 10 {
t.Errorf("expected no error. got: %v, scan bytes: %v, produced rows: %v",
qStatus.ErrorCode, qStatus.ScanBytes, qStatus.ProducedRows)
return
}
if qStatus.ErrorCode != "" || qStatus.ScanBytes != 2048 || qStatus.ProducedRows != 10 {
t.Errorf("expected no error. got: %v, scan bytes: %v, produced rows: %v",
qStatus.ErrorCode, qStatus.ScanBytes, qStatus.ProducedRows)
return
}
})
}

func TestGetInvalidQueryStatus(t *testing.T) {
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
}
ctx := context.Background()
sc, err := buildSnowflakeConn(ctx, *config)
if err != nil {
t.Error(err)
}
if err = authenticateWithConfig(sc); err != nil {
t.Error(err)
}

sc.rest.RequestTimeout = 1 * time.Second
runSnowflakeConnTest(t, func(sct *SCTest) {
ctx := sct.sc.ctx
sct.sc.rest.RequestTimeout = 1 * time.Second

qStatus, err := sc.checkQueryStatus(ctx, "1234")
if err == nil || qStatus != nil {
t.Error("expected an error")
}
qStatus, err := sct.sc.checkQueryStatus(ctx, "1234")
if err == nil || qStatus != nil {
t.Error("expected an error")
}
})
}

func TestExecWithServerSideError(t *testing.T) {
Expand Down Expand Up @@ -575,156 +554,102 @@ func executeQueryAndConfirmMessage(db *sql.DB, query string, expectedErrorTable
}

func TestQueryArrowStreamError(t *testing.T) {
ctx := context.Background()
numrows := 50000 // approximately 10 ArrowBatch objects
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
}
sc, err := buildSnowflakeConn(ctx, *config)
if err != nil {
t.Error(err)
}
if err = authenticateWithConfig(sc); err != nil {
t.Error(err)
}

query := fmt.Sprintf(selectRandomGenerator, numrows)
sr := &snowflakeRestful{
FuncPostQuery: postQueryTest,
TokenAccessor: getSimpleTokenAccessor(),
RequestTimeout: 10,
}
sc.rest = sr
_, err = sc.QueryArrowStream(ctx, query)
if err == nil {
t.Error("should have raised an error")
}
runSnowflakeConnTest(t, func(sct *SCTest) {
sc := sct.sc
numrows := 50000 // approximately 10 ArrowBatch objects
query := fmt.Sprintf(selectRandomGenerator, numrows)
sr := &snowflakeRestful{
FuncPostQuery: postQueryTest,
FuncCloseSession: closeSessionMock,
TokenAccessor: getSimpleTokenAccessor(),
RequestTimeout: 10,
}
sc.rest = sr
_, err := sc.QueryArrowStream(sc.ctx, query)
if err == nil {
t.Error("should have raised an error")
}

sc.rest.FuncPostQuery = postQueryFail
_, err = sc.QueryArrowStream(ctx, query)
if err == nil {
t.Error("should have raised an error")
}
_, ok := err.(*SnowflakeError)
if !ok {
t.Fatalf("should be snowflake error. err: %v", err)
}
sc.rest.FuncPostQuery = postQueryFail
_, err = sc.QueryArrowStream(sc.ctx, query)
if err == nil {
t.Error("should have raised an error")
}
_, ok := err.(*SnowflakeError)
if !ok {
t.Fatalf("should be snowflake error. err: %v", err)
}
})
}

func TestExecContextError(t *testing.T) {
ctx := context.Background()
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
}
sc, err := buildSnowflakeConn(ctx, *config)
if err != nil {
t.Error(err)
}
if err = authenticateWithConfig(sc); err != nil {
t.Error(err)
}

sr := &snowflakeRestful{
FuncPostQuery: postQueryTest,
TokenAccessor: getSimpleTokenAccessor(),
RequestTimeout: 10,
}

sc.rest = sr
runSnowflakeConnTest(t, func(sct *SCTest) {
sc := sct.sc
ctx := sc.ctx
sr := &snowflakeRestful{
FuncPostQuery: postQueryTest,
FuncCloseSession: closeSessionMock,
TokenAccessor: getSimpleTokenAccessor(),
RequestTimeout: 10,
}
sc.rest = sr

_, err = sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{})
if err == nil {
t.Fatalf("should have raised an error")
}
_, err := sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{})
if err == nil {
t.Fatalf("should have raised an error")
}

sc.rest.FuncPostQuery = postQueryFail
_, err = sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{})
if err == nil {
t.Fatalf("should have raised an error")
}
_, ok := err.(*SnowflakeError)
if !ok {
t.Fatalf("should be snowflake error. err: %v", err)
}
sc.rest.FuncPostQuery = postQueryFail
_, err = sc.ExecContext(ctx, "SELECT 1", []driver.NamedValue{})
if err == nil {
t.Fatalf("should have raised an error")
}
})
}

func TestQueryContextError(t *testing.T) {
ctx := context.Background()
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
}
sc, err := buildSnowflakeConn(ctx, *config)
if err != nil {
t.Error(err)
}
if err = authenticateWithConfig(sc); err != nil {
t.Error(err)
}

sr := &snowflakeRestful{
FuncPostQuery: postQueryTest,
TokenAccessor: getSimpleTokenAccessor(),
RequestTimeout: 10,
}

sc.rest = sr
runSnowflakeConnTest(t, func(sct *SCTest) {
ctx := sct.sc.ctx
sr := &snowflakeRestful{
FuncPostQuery: postQueryTest,
FuncCloseSession: closeSessionMock,
TokenAccessor: getSimpleTokenAccessor(),
RequestTimeout: 10,
}

_, err = sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{})
if err == nil {
t.Fatalf("should have raised an error")
}
sct.sc.rest = sr
_, err := sct.sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{})
if err == nil {
t.Fatalf("should have raised an error")
}

sc.rest.FuncPostQuery = postQueryFail
_, err = sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{})
if err == nil {
t.Fatalf("should have raised an error")
}
_, ok := err.(*SnowflakeError)
if !ok {
t.Fatalf("should be snowflake error. err: %v", err)
}
sct.sc.rest.FuncPostQuery = postQueryFail
_, err = sct.sc.QueryContext(ctx, "SELECT 1", []driver.NamedValue{})
if err == nil {
t.Fatalf("should have raised an error")
}
_, ok := err.(*SnowflakeError)
if !ok {
t.Fatalf("should be snowflake error. err: %v", err)
}
})
}

func TestPrepareQuery(t *testing.T) {
ctx := context.Background()
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
}
sc, err := buildSnowflakeConn(ctx, *config)
if err != nil {
t.Error(err)
}
if err = authenticateWithConfig(sc); err != nil {
t.Error(err)
}
_, err = sc.Prepare("SELECT 1")
runSnowflakeConnTest(t, func(sct *SCTest) {
_, err := sct.sc.Prepare("SELECT 1")

if err != nil {
t.Fatalf("failed to prepare query. err: %v", err)
}
sc.Close()
if err != nil {
t.Fatalf("failed to prepare query. err: %v", err)
}
})
}

func TestBeginCreatesTransaction(t *testing.T) {
ctx := context.Background()
config, err := ParseDSN(dsn)
if err != nil {
t.Error(err)
}
sc, err := buildSnowflakeConn(ctx, *config)
if err != nil {
t.Error(err)
}
if err = authenticateWithConfig(sc); err != nil {
t.Error(err)
}
tx, _ := sc.Begin()
if tx == nil {
t.Fatal("should have created a transaction with connection")
}
sc.Close()
runSnowflakeConnTest(t, func(sct *SCTest) {
tx, _ := sct.sc.Begin()
if tx == nil {
t.Fatal("should have created a transaction with connection")
}
})
}
Loading

0 comments on commit ba92972

Please sign in to comment.