diff --git a/CHANGELOG.md b/CHANGELOG.md
index a67fbd35adc..07eca3d1aae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@
 
 ### Bug Fixes
 
 - Fixed a bug where imports from permanent stage locations were ignored for temporary stored procedures, UDTFs, UDFs, and UDAFs.
+- Reverted to using a CTAS (CREATE TABLE AS SELECT) statement for `Dataframe.writer.save_as_table`, which does not require INSERT privilege to write to the target table.
 
 ## 1.8.0 (2023-09-14)
diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
index 07bb4e679e3..c8909391625 100644
--- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
+++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
@@ -715,14 +715,22 @@ def batch_insert_into_statement(table_name: str, column_names: List[str]) -> str
 
 
 def create_table_as_select_statement(
     table_name: str,
     child: str,
+    column_definition: str,
     replace: bool = False,
     error: bool = True,
     table_type: str = EMPTY_STRING,
+    clustering_key: Optional[Iterable[str]] = None,
 ) -> str:
+    cluster_by_clause = (
+        (CLUSTER_BY + LEFT_PARENTHESIS + COMMA.join(clustering_key) + RIGHT_PARENTHESIS)
+        if clustering_key
+        else EMPTY_STRING
+    )
     return (
         f"{CREATE}{OR + REPLACE if replace else EMPTY_STRING} {table_type.upper()} {TABLE}"
         f"{IF + NOT + EXISTS if not replace and not error else EMPTY_STRING}"
-        f" {table_name}{AS}{project_statement([], child)}"
+        f" {table_name}{LEFT_PARENTHESIS}{column_definition}{RIGHT_PARENTHESIS}"
+        f"{cluster_by_clause} {AS}{project_statement([], child)}"
     )
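For reference, the updated helper now emits a CTAS of roughly the following shape. This is a minimal standalone sketch, not the module's code: the real function assembles the string from shared tokens (`CREATE`, `CLUSTER_BY`, `EMPTY_STRING`, ...) and delegates the trailing query to `project_statement`, so exact whitespace differs. `ctas_sketch` and its sample arguments are hypothetical.

```python
from typing import Iterable, Optional


def ctas_sketch(
    table_name: str,
    child: str,  # SQL text of the query being materialized
    column_definition: str,  # e.g. '"A" BIGINT, "B" STRING'
    replace: bool = False,
    error: bool = True,
    table_type: str = "",
    clustering_key: Optional[Iterable[str]] = None,
) -> str:
    # CLUSTER BY (...) is emitted only when clustering keys are supplied.
    cluster_by = f" CLUSTER BY ({', '.join(clustering_key)})" if clustering_key else ""
    create = "CREATE OR REPLACE" if replace else "CREATE"
    # IF NOT EXISTS is what backs the "ignore" save mode (replace=False, error=False).
    if_not_exists = " IF NOT EXISTS" if not replace and not error else ""
    parts = [create]
    if table_type:  # e.g. "temporary" or "transient"
        parts.append(table_type.upper())
    parts.append(f"TABLE{if_not_exists} {table_name} ({column_definition})")
    return " ".join(parts) + cluster_by + f" AS SELECT * FROM ({child})"


print(ctas_sketch("MY_TABLE", "SELECT 1 AS A", '"A" BIGINT', replace=True, clustering_key=['"A"']))
# CREATE OR REPLACE TABLE MY_TABLE ("A" BIGINT) CLUSTER BY ("A") AS SELECT * FROM (SELECT 1 AS A)
```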
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
index c35214a38a7..9cd1f5c632b 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -563,11 +563,12 @@ def save_as_table(
         child: SnowflakePlan,
     ) -> SnowflakePlan:
         full_table_name = ".".join(table_name)
+        column_definition = attribute_to_schema_string(child.attributes)
 
         def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
             create_table = create_table_statement(
                 full_table_name,
-                attribute_to_schema_string(child.attributes),
+                column_definition,
                 replace=replace,
                 error=error,
                 table_type=table_type,
@@ -612,20 +613,43 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
             else:
                 return get_create_and_insert_plan(child, replace=False, error=False)
         elif mode == SaveMode.OVERWRITE:
-            return get_create_and_insert_plan(child, replace=True)
+            return self.build(
+                lambda x: create_table_as_select_statement(
+                    full_table_name,
+                    x,
+                    column_definition,
+                    replace=True,
+                    table_type=table_type,
+                    clustering_key=clustering_keys,
+                ),
+                child,
+                None,
+            )
         elif mode == SaveMode.IGNORE:
-            if self.session._table_exists(table_name):
-                return self.build(
-                    lambda x: create_table_as_select_statement(
-                        full_table_name, x, error=False, table_type=table_type
-                    ),
-                    child,
-                    None,
-                )
-            else:
-                return get_create_and_insert_plan(child, replace=False, error=False)
+            return self.build(
+                lambda x: create_table_as_select_statement(
+                    full_table_name,
+                    x,
+                    column_definition,
+                    error=False,
+                    table_type=table_type,
+                    clustering_key=clustering_keys,
+                ),
+                child,
+                None,
+            )
         elif mode == SaveMode.ERROR_IF_EXISTS:
-            return get_create_and_insert_plan(child, replace=False, error=True)
+            return self.build(
+                lambda x: create_table_as_select_statement(
+                    full_table_name,
+                    x,
+                    column_definition,
+                    table_type=table_type,
+                    clustering_key=clustering_keys,
+                ),
+                child,
+                None,
+            )
 
     def limit(
         self,
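The net effect of this rewrite: `OVERWRITE`, `IGNORE`, and `ERROR_IF_EXISTS` each compile to a single CTAS (so only table-creation privilege is exercised, per the changelog entry), and the `IGNORE` branch no longer issues the `_table_exists()` pre-check round trip. `APPEND` keeps the create-and-insert plan, since it must add rows to a table that may already exist. A hedged usage sketch; `connection_parameters` and the table name are placeholders, and the SQL in the comments is approximate:

```python
from snowflake.snowpark import Session

# Placeholder credentials; connection_parameters is assumed, not part of this change.
connection_parameters = {"account": "<account>", "user": "<user>", "password": "<password>"}
session = Session.builder.configs(connection_parameters).create()

df = session.create_dataframe([(1, "a"), (2, "b")], schema=["id", "name"])

# OVERWRITE now compiles to roughly:
#   CREATE OR REPLACE TABLE my_table ("ID" BIGINT, "NAME" STRING) CLUSTER BY ("ID") AS SELECT ...
df.write.save_as_table("my_table", mode="overwrite", clustering_keys=["id"])

# IGNORE compiles to CREATE TABLE IF NOT EXISTS ... AS SELECT ...,
# with no separate existence check beforehand.
df.write.save_as_table("my_table", mode="ignore")

# ERRORIFEXISTS compiles to a plain CREATE TABLE ... AS SELECT ... and fails if
# the table already exists; APPEND is unchanged and still creates-then-inserts.
```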
diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py
index c40ba4176d3..7c098dac3f9 100644
--- a/tests/integ/test_dataframe.py
+++ b/tests/integ/test_dataframe.py
@@ -2306,11 +2306,10 @@ def test_table_types_in_save_as_table(session, save_mode, table_type):
     Utils.drop_table(session, table_name)
 
 
-@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
 @pytest.mark.parametrize(
     "save_mode", ["append", "overwrite", "ignore", "errorifexists"]
 )
-def test_save_as_table_respects_schema(session, save_mode, table_type):
+def test_save_as_table_respects_schema(session, save_mode):
     table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
     schema1 = StructType(
         [
@@ -2325,32 +2324,29 @@ def test_save_as_table_respects_schema(session, save_mode):
     df2 = session.create_dataframe([(1), (2)], schema=schema2)
 
     try:
-        df1.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+        df1.write.save_as_table(table_name, mode=save_mode)
         saved_df = session.table(table_name)
         Utils.is_schema_same(saved_df.schema, schema1)
 
         if save_mode == "overwrite":
-            df2.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+            df2.write.save_as_table(table_name, mode=save_mode)
             saved_df = session.table(table_name)
             Utils.is_schema_same(saved_df.schema, schema2)
         elif save_mode == "ignore":
-            df2.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+            df2.write.save_as_table(table_name, mode=save_mode)
             saved_df = session.table(table_name)
             Utils.is_schema_same(saved_df.schema, schema1)
         else:  # save_mode in ('append', 'errorifexists')
             with pytest.raises(SnowparkSQLException):
-                df2.write.save_as_table(
-                    table_name, mode=save_mode, table_type=table_type
-                )
+                df2.write.save_as_table(table_name, mode=save_mode)
     finally:
         Utils.drop_table(session, table_name)
 
 
-@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
 @pytest.mark.parametrize(
     "save_mode", ["append", "overwrite", "ignore", "errorifexists"]
 )
-def test_save_as_table_nullable_test(session, save_mode, table_type):
+def test_save_as_table_nullable_test(session, save_mode):
     table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
     schema = StructType(
         [
@@ -2365,7 +2361,7 @@ def test_save_as_table_nullable_test(session, save_mode):
             (IntegrityError, SnowparkSQLException),
             match="NULL result in a non-nullable column",
         ):
-            df.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+            df.write.save_as_table(table_name, mode=save_mode)
     finally:
         Utils.drop_table(session, table_name)
 
@@ -2397,9 +2393,8 @@ def test_save_as_table_with_table_sproc_output(session, save_mode, table_type):
     Utils.drop_procedure(session, f"{temp_sp_name}()")
 
 
-@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
 @pytest.mark.parametrize("save_mode", ["append", "overwrite"])
-def test_write_table_with_clustering_keys(session, save_mode, table_type):
+def test_write_table_with_clustering_keys(session, save_mode):
     table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
     table_name2 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
     table_name3 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
@@ -2433,7 +2428,6 @@ def test_write_table_with_clustering_keys(session, save_mode):
         df1.write.save_as_table(
             table_name1,
             mode=save_mode,
-            table_type=table_type,
             clustering_keys=["c1", "c2"],
         )
         ddl = session._run_query(f"select get_ddl('table', '{table_name1}')")[0][0]
@@ -2442,7 +2436,6 @@ def test_write_table_with_clustering_keys(session, save_mode):
         df2.write.save_as_table(
             table_name2,
             mode=save_mode,
-            table_type=table_type,
             clustering_keys=[
                 col("c1").cast(DateType()),
                 col("c2").substring(0, 10),
@@ -2454,7 +2447,6 @@ def test_write_table_with_clustering_keys(session, save_mode):
         df3.write.save_as_table(
             table_name3,
             mode=save_mode,
-            table_type=table_type,
             clustering_keys=[get_path(col("v"), lit("Data.id")).cast(IntegerType())],
         )
         ddl = session._run_query(f"select get_ddl('table', '{table_name3}')")[0][0]
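The `table_type` parametrization is dropped from these tests, but the verification pattern they rely on is unchanged: write with `clustering_keys`, then assert that GET_DDL echoes the clustering clause back. A minimal sketch of that pattern, reusing the `session` from the previous example and assuming the table name is free:

```python
from snowflake.snowpark.types import StringType, StructField, StructType

# Write a small table with clustering keys on both columns.
schema = StructType([StructField("c1", StringType()), StructField("c2", StringType())])
df = session.create_dataframe([("2023-01-01", "ab")], schema=schema)
df.write.save_as_table("clustered_t", mode="overwrite", clustering_keys=["c1", "c2"])

# GET_DDL reports the clustering expression, which is what these tests assert on.
ddl = session._run_query("select get_ddl('table', 'clustered_t')")[0][0]
assert "cluster by" in ddl.lower()
```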