diff --git a/aaa_test.go b/aaa_test.go new file mode 100644 index 000000000..de86ae1f0 --- /dev/null +++ b/aaa_test.go @@ -0,0 +1,15 @@ +package gosnowflake + +import "testing" + +func TestShowServerVersion(t *testing.T) { + runDBTest(t, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT CURRENT_VERSION()") + defer rows.Close() + + var version string + rows.Next() + rows.Scan(&version) + println(version) + }) +} diff --git a/connection.go b/connection.go index 4ea8dbd77..9b83d229f 100644 --- a/connection.go +++ b/connection.go @@ -159,10 +159,18 @@ func (sc *snowflakeConn) exec( } logger.WithContext(ctx).Info("Exec/Query SUCCESS") - sc.cfg.Database = data.Data.FinalDatabaseName - sc.cfg.Schema = data.Data.FinalSchemaName - sc.cfg.Role = data.Data.FinalRoleName - sc.cfg.Warehouse = data.Data.FinalWarehouseName + if data.Data.FinalDatabaseName != "" { + sc.cfg.Database = data.Data.FinalDatabaseName + } + if data.Data.FinalSchemaName != "" { + sc.cfg.Schema = data.Data.FinalSchemaName + } + if data.Data.FinalWarehouseName != "" { + sc.cfg.Warehouse = data.Data.FinalWarehouseName + } + if data.Data.FinalRoleName != "" { + sc.cfg.Role = data.Data.FinalRoleName + } sc.populateSessionParameters(data.Data.Parameters) return data, err } diff --git a/driver_test.go b/driver_test.go index fadbe3a55..1d6662ebe 100644 --- a/driver_test.go +++ b/driver_test.go @@ -387,6 +387,7 @@ func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) { } sct := &SCTest{t, sc} + test(sct) } diff --git a/htap_test.go b/htap_test.go index 8aae85810..fd3ace374 100644 --- a/htap_test.go +++ b/htap_test.go @@ -1,10 +1,14 @@ package gosnowflake import ( + "context" "encoding/json" + "fmt" "reflect" + "strconv" "strings" "testing" + "time" ) func TestMarshallAndDecodeOpaqueContext(t *testing.T) { @@ -426,3 +430,165 @@ func TestQueryContextCacheDisabled(t *testing.T) { } }) } + +func TestHTAPOptimizations(t *testing.T) { + for _, useHtapOptimizations := range []bool{true, false} { + runSnowflakeConnTest(t, func(sct *SCTest) { + t.Run("useHtapOptimizations="+strconv.FormatBool(useHtapOptimizations), func(t *testing.T) { + if useHtapOptimizations { + sct.mustExec("ALTER SESSION SET ENABLE_SNOW_654741_FOR_TESTING = true", nil) + } + runID := time.Now().UnixMilli() + t.Run("Schema", func(t *testing.T) { + newSchema := fmt.Sprintf("test_schema_%v", runID) + if strings.EqualFold(sct.sc.cfg.Schema, newSchema) { + t.Errorf("schema should not be switched") + } + + sct.mustExec(fmt.Sprintf("CREATE SCHEMA %v", newSchema), nil) + defer sct.mustExec(fmt.Sprintf("DROP SCHEMA %v", newSchema), nil) + + if !strings.EqualFold(sct.sc.cfg.Schema, newSchema) { + t.Errorf("schema should be switched, expected %v, got %v", newSchema, sct.sc.cfg.Schema) + } + + query := sct.mustQuery("SELECT 1", nil) + query.Close() + + if !strings.EqualFold(sct.sc.cfg.Schema, newSchema) { + t.Errorf("schema should be switched, expected %v, got %v", newSchema, sct.sc.cfg.Schema) + } + }) + t.Run("Database", func(t *testing.T) { + newDatabase := fmt.Sprintf("test_database_%v", runID) + if strings.EqualFold(sct.sc.cfg.Database, newDatabase) { + t.Errorf("database should not be switched") + } + + // TODO replace with mustExec + sct.mustExec(fmt.Sprintf("CREATE DATABASE %v", newDatabase), nil) + defer sct.mustExec(fmt.Sprintf("DROP DATABASE %v", newDatabase), nil) + + if !strings.EqualFold(sct.sc.cfg.Database, newDatabase) { + t.Errorf("database should be switched, expected %v, got %v", newDatabase, sct.sc.cfg.Database) + } + + query := sct.mustQuery("SELECT 1", nil) + query.Close() + + if !strings.EqualFold(sct.sc.cfg.Database, newDatabase) { + t.Errorf("database should be switched, expected %v, got %v", newDatabase, sct.sc.cfg.Database) + } + }) + t.Run("Warehouse", func(t *testing.T) { + newWarehouse := fmt.Sprintf("test_warehouse_%v", runID) + if strings.EqualFold(sct.sc.cfg.Warehouse, newWarehouse) { + t.Errorf("warehouse should not be switched") + } + + // TODO replace with mustExec + sct.mustExec(fmt.Sprintf("CREATE WAREHOUSE %v", newWarehouse), nil) + defer sct.mustExec(fmt.Sprintf("DROP WAREHOUSE %v", newWarehouse), nil) + + if !strings.EqualFold(sct.sc.cfg.Warehouse, newWarehouse) { + t.Errorf("warehouse should be switched, expected %v, got %v", newWarehouse, sct.sc.cfg.Warehouse) + } + + query := sct.mustQuery("SELECT 1", nil) + query.Close() + + if !strings.EqualFold(sct.sc.cfg.Warehouse, newWarehouse) { + t.Errorf("warehouse should be switched, expected %v, got %v", newWarehouse, sct.sc.cfg.Warehouse) + } + }) + t.Run("Role", func(t *testing.T) { + if strings.EqualFold(sct.sc.cfg.Role, "PUBLIC") { + t.Errorf("role should not be public for this test") + } + + sct.mustExec("USE ROLE public", nil) + + if !strings.EqualFold(sct.sc.cfg.Role, "PUBLIC") { + t.Errorf("role should be switched, expected public, got %v", sct.sc.cfg.Warehouse) + } + + query := sct.mustQuery("SELECT 1", nil) + query.Close() + + if !strings.EqualFold(sct.sc.cfg.Role, "PUBLIC") { + t.Errorf("role should be switched, expected public, got %v", sct.sc.cfg.Warehouse) + } + }) + t.Run("Session param - DATE_OUTPUT_FORMAT", func(t *testing.T) { + if !strings.EqualFold(*sct.sc.cfg.Params["date_output_format"], "YYYY-MM-DD") { + t.Errorf("should use default date_output_format, but got: %v", *sct.sc.cfg.Params["date_output_format"]) + } + + // TODO replace with mustExec + sct.mustExec("ALTER SESSION SET DATE_OUTPUT_FORMAT = 'DD-MM-YYYY'", nil) + defer sct.mustExec("ALTER SESSION SET DATE_OUTPUT_FORMAT = 'YYYY-MM-DD'", nil) + + if !strings.EqualFold(*sct.sc.cfg.Params["date_output_format"], "DD-MM-YYYY") { + t.Errorf("role should be switched, expected public, got %v", sct.sc.cfg.Warehouse) + } + + query := sct.mustQuery("SELECT 1", nil) + query.Close() + + if !strings.EqualFold(*sct.sc.cfg.Params["date_output_format"], "DD-MM-YYYY") { + t.Errorf("role should be switched, expected public, got %v", sct.sc.cfg.Warehouse) + } + }) + }) + }) + } +} + +func TestConnIsCleanAfterClose(t *testing.T) { + // We create a new db here to not use the default pool as we can leave it in dirty state. + t.Skip("Fails, because connection is returned to a pool dirty") + ctx := context.Background() + runID := time.Now().UnixMilli() + + db := openDB(t) + defer db.Close() + db.SetMaxOpenConns(1) + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + dbt := DBTest{t, conn} + + dbt.mustExec(forceJSON) + + var dbName string + rows1 := dbt.mustQuery("SELECT CURRENT_DATABASE()") + rows1.Next() + rows1.Scan(&dbName) + + newDbName := fmt.Sprintf("test_database_%v", runID) + dbt.mustExec("CREATE DATABASE " + newDbName) + + rows1.Close() + conn.Close() + + conn2, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + dbt2 := DBTest{t, conn2} + + var dbName2 string + rows2 := dbt2.mustQuery("SELECT CURRENT_DATABASE()") + defer rows2.Close() + rows2.Next() + rows2.Scan(&dbName2) + + if !strings.EqualFold(dbName, dbName2) { + t.Errorf("fresh connection from pool should have original database") + } +} diff --git a/statement_test.go b/statement_test.go index 68f1e7a7b..963e0a61c 100644 --- a/statement_test.go +++ b/statement_test.go @@ -36,6 +36,9 @@ func openConn(t *testing.T) *sql.Conn { if conn, err = db.Conn(context.Background()); err != nil { t.Fatalf("failed to open connection: %v", err) } + + conn.ExecContext(context.Background(), "ALTER SESSION SET ENABLE_SNOW_654741_FOR_TESTING = true") + return conn }