SNOW-911327: Use CTAS for save as table #1075

Merged: 3 commits, Oct 4, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@
### Bug Fixes

- Fixed a bug where imports from permanent stage locations were ignored for temporary stored procedures, UDTFs, UDFs, and UDAFs.
- Reverted to using a CTAS (create table as select) statement for `Dataframe.writer.save_as_table`, which does not need insert permission for writing tables.

## 1.8.0 (2023-09-14)

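The changelog entry above hinges on a privilege difference between the two write paths. A minimal sketch of the two SQL shapes involved, assuming made-up table, column, and query names; this is illustrative only, not the library's code generator:

```python
# Illustrative only: contrasts the two SQL shapes behind save_as_table.
# Table/column/query names are invented for the example.

def create_then_insert_sql(table: str, columns: str, select_sql: str) -> list:
    # Older path: create an empty table, then insert the query result into it.
    # The separate INSERT statement is what requires INSERT privilege.
    return [
        f"CREATE TABLE {table} ({columns})",
        f"INSERT INTO {table} {select_sql}",
    ]

def ctas_sql(table: str, columns: str, select_sql: str) -> str:
    # CTAS path restored by this PR: a single statement, no separate INSERT.
    return f"CREATE TABLE {table} ({columns}) AS {select_sql}"

print(ctas_sql("MY_TABLE", "A INT, B STRING", "SELECT a, b FROM src"))
# CREATE TABLE MY_TABLE (A INT, B STRING) AS SELECT a, b FROM src
```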
10 changes: 9 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
@@ -715,14 +715,22 @@ def batch_insert_into_statement(table_name: str, column_names: List[str]) -> str
def create_table_as_select_statement(
table_name: str,
child: str,
+ column_definition: str,
replace: bool = False,
error: bool = True,
table_type: str = EMPTY_STRING,
+ clustering_key: Optional[Iterable[str]] = None,
) -> str:
+ cluster_by_clause = (
+     (CLUSTER_BY + LEFT_PARENTHESIS + COMMA.join(clustering_key) + RIGHT_PARENTHESIS)
+     if clustering_key
+     else EMPTY_STRING
+ )
return (
f"{CREATE}{OR + REPLACE if replace else EMPTY_STRING} {table_type.upper()} {TABLE}"
f"{IF + NOT + EXISTS if not replace and not error else EMPTY_STRING}"
f" {table_name}{AS}{project_statement([], child)}"
f" {table_name}{LEFT_PARENTHESIS}{column_definition}{RIGHT_PARENTHESIS}"
f"{cluster_by_clause} {AS}{project_statement([], child)}"
)


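For context, a rough sketch of the statement shape that `create_table_as_select_statement` assembles. The real helper composes SQL-token constants (CREATE, CLUSTER_BY, and so on) defined elsewhere in the module and wraps the child query via `project_statement`; the literal fragments, the helper name, and the example inputs below are assumptions for illustration only:

```python
from typing import Iterable, Optional

def ctas_statement_sketch(
    table_name: str,
    child: str,
    column_definition: str,
    replace: bool = False,
    error: bool = True,
    table_type: str = "",
    clustering_key: Optional[Iterable[str]] = None,
) -> str:
    # Same clause assembly as above, written with literal SQL fragments.
    cluster_by = f" CLUSTER BY ({', '.join(clustering_key)})" if clustering_key else ""
    create = "CREATE OR REPLACE" if replace else "CREATE"
    if_not_exists = " IF NOT EXISTS" if not replace and not error else ""
    return (
        f"{create} {table_type.upper()} TABLE{if_not_exists}"
        f" {table_name} ({column_definition}){cluster_by}"
        f" AS SELECT * FROM ({child})"
    )

print(
    ctas_statement_sketch(
        "MY_TABLE",
        "SELECT 1 AS A",
        "A BIGINT",
        replace=True,
        table_type="transient",
        clustering_key=["A"],
    )
)
# CREATE OR REPLACE TRANSIENT TABLE MY_TABLE (A BIGINT) CLUSTER BY (A) AS SELECT * FROM (SELECT 1 AS A)
```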
50 changes: 37 additions & 13 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -563,11 +563,12 @@ def save_as_table(
child: SnowflakePlan,
) -> SnowflakePlan:
full_table_name = ".".join(table_name)
+ column_definition = attribute_to_schema_string(child.attributes)

def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
create_table = create_table_statement(
full_table_name,
- attribute_to_schema_string(child.attributes),
+ column_definition,
replace=replace,
error=error,
table_type=table_type,
@@ -612,20 +613,43 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace=False, error=True):
else:
return get_create_and_insert_plan(child, replace=False, error=False)
elif mode == SaveMode.OVERWRITE:
- return get_create_and_insert_plan(child, replace=True)
+ return self.build(
+     lambda x: create_table_as_select_statement(
+         full_table_name,
+         x,
+         column_definition,
+         replace=True,
+         table_type=table_type,
+         clustering_key=clustering_keys,
+     ),
+     child,
+     None,
+ )
elif mode == SaveMode.IGNORE:
- if self.session._table_exists(table_name):
-     return self.build(
-         lambda x: create_table_as_select_statement(
-             full_table_name, x, error=False, table_type=table_type
-         ),
-         child,
-         None,
-     )
- else:
-     return get_create_and_insert_plan(child, replace=False, error=False)
+ return self.build(
+     lambda x: create_table_as_select_statement(
+         full_table_name,
+         x,
+         column_definition,
+         error=False,
+         table_type=table_type,
+         clustering_key=clustering_keys,
+     ),
+     child,
+     None,
+ )
elif mode == SaveMode.ERROR_IF_EXISTS:
- return get_create_and_insert_plan(child, replace=False, error=True)
+ return self.build(
+     lambda x: create_table_as_select_statement(
+         full_table_name,
+         x,
+         column_definition,
+         table_type=table_type,
+         clustering_key=clustering_keys,
+     ),
+     child,
+     None,
+ )

def limit(
self,
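Taken together, the branches above mean each non-append save mode now resolves to a single CTAS variant, while append keeps the create-then-insert path. A condensed sketch of that mapping, using an invented summary helper purely to restate the control flow; the exact append behavior is paraphrased from the surrounding diff:

```python
from enum import Enum

class SaveMode(Enum):
    APPEND = "append"
    OVERWRITE = "overwrite"
    IGNORE = "ignore"
    ERROR_IF_EXISTS = "errorifexists"

def statement_shape(mode: SaveMode) -> str:
    # Invented helper; mirrors the branching in save_as_table above.
    if mode == SaveMode.APPEND:
        # Append still creates the table if needed and then INSERTs,
        # so it remains the one mode that needs INSERT privilege.
        return "CREATE TABLE IF NOT EXISTS ... ; INSERT INTO ... SELECT ..."
    if mode == SaveMode.OVERWRITE:
        return "CREATE OR REPLACE TABLE ... AS SELECT ..."      # replace=True
    if mode == SaveMode.IGNORE:
        return "CREATE TABLE IF NOT EXISTS ... AS SELECT ..."   # error=False
    return "CREATE TABLE ... AS SELECT ..."                      # ERROR_IF_EXISTS

for m in SaveMode:
    print(f"{m.value}: {statement_shape(m)}")
```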
24 changes: 8 additions & 16 deletions tests/integ/test_dataframe.py
@@ -2306,11 +2306,10 @@ def test_table_types_in_save_as_table(session, save_mode, table_type):
Utils.drop_table(session, table_name)


@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
@pytest.mark.parametrize(
"save_mode", ["append", "overwrite", "ignore", "errorifexists"]
)
- def test_save_as_table_respects_schema(session, save_mode, table_type):
Comment on lines -2309 to -2313 (Contributor Author):

Removing the excessive table_type parametrization from tests where it doesn't add any value. I do want to hear the team's opinion on this.

+ def test_save_as_table_respects_schema(session, save_mode):
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)

schema1 = StructType(
@@ -2325,32 +2324,29 @@ def test_save_as_table_respects_schema(session, save_mode, table_type):
df2 = session.create_dataframe([(1), (2)], schema=schema2)

try:
- df1.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+ df1.write.save_as_table(table_name, mode=save_mode)
saved_df = session.table(table_name)
Utils.is_schema_same(saved_df.schema, schema1)

if save_mode == "overwrite":
- df2.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+ df2.write.save_as_table(table_name, mode=save_mode)
saved_df = session.table(table_name)
Utils.is_schema_same(saved_df.schema, schema2)
elif save_mode == "ignore":
- df2.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+ df2.write.save_as_table(table_name, mode=save_mode)
saved_df = session.table(table_name)
Utils.is_schema_same(saved_df.schema, schema1)
else: # save_mode in ('append', 'errorifexists')
with pytest.raises(SnowparkSQLException):
- df2.write.save_as_table(
-     table_name, mode=save_mode, table_type=table_type
- )
+ df2.write.save_as_table(table_name, mode=save_mode)
finally:
Utils.drop_table(session, table_name)


@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
@pytest.mark.parametrize(
"save_mode", ["append", "overwrite", "ignore", "errorifexists"]
)
- def test_save_as_table_nullable_test(session, save_mode, table_type):
+ def test_save_as_table_nullable_test(session, save_mode):
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
schema = StructType(
[
@@ -2365,7 +2361,7 @@ def test_save_as_table_nullable_test(session, save_mode, table_type):
(IntegrityError, SnowparkSQLException),
match="NULL result in a non-nullable column",
):
- df.write.save_as_table(table_name, mode=save_mode, table_type=table_type)
+ df.write.save_as_table(table_name, mode=save_mode)
finally:
Utils.drop_table(session, table_name)

@@ -2397,9 +2393,8 @@ def test_save_as_table_with_table_sproc_output(session, save_mode, table_type):
Utils.drop_procedure(session, f"{temp_sp_name}()")


@pytest.mark.parametrize("table_type", ["", "temp", "temporary", "transient"])
@pytest.mark.parametrize("save_mode", ["append", "overwrite"])
- def test_write_table_with_clustering_keys(session, save_mode, table_type):
+ def test_write_table_with_clustering_keys(session, save_mode):
table_name1 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
table_name2 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
table_name3 = Utils.random_name_for_temp_object(TempObjectType.TABLE)
@@ -2433,7 +2428,6 @@
df1.write.save_as_table(
table_name1,
mode=save_mode,
- table_type=table_type,
clustering_keys=["c1", "c2"],
)
ddl = session._run_query(f"select get_ddl('table', '{table_name1}')")[0][0]
@@ -2442,7 +2436,6 @@ def test_write_table_with_clustering_keys(session, save_mode, table_type):
df2.write.save_as_table(
table_name2,
mode=save_mode,
- table_type=table_type,
clustering_keys=[
col("c1").cast(DateType()),
col("c2").substring(0, 10),
@@ -2454,7 +2447,6 @@ def test_write_table_with_clustering_keys(session, save_mode, table_type):
df3.write.save_as_table(
table_name3,
mode=save_mode,
- table_type=table_type,
clustering_keys=[get_path(col("v"), lit("Data.id")).cast(IntegerType())],
)
ddl = session._run_query(f"select get_ddl('table', '{table_name3}')")[0][0]
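As a usage note, the writer API exercised by these tests is unchanged; only the generated SQL differs. A small hypothetical example, assuming an existing `snowflake.snowpark.Session` named `session` and a made-up table name:

```python
from snowflake.snowpark.functions import col

# `session` is assumed to be an existing snowflake.snowpark.Session.
df = session.create_dataframe([(1, "a"), (2, "b")], schema=["c1", "c2"])

# With this change, overwrite mode issues a single
# CREATE OR REPLACE TABLE ... AS SELECT, so the role needs CREATE TABLE
# privilege on the schema but no INSERT privilege on the table.
df.write.save_as_table(
    "my_ctas_table",          # hypothetical table name
    mode="overwrite",
    clustering_keys=[col("c1")],
)
```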