diff --git a/CHANGELOG.md b/CHANGELOG.md index 98feea9cdaa..a3da40594f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ - Added support for `Series.str.center`. - Added support for `Series.str.pad`. +#### Bug Fixes + +- Fixed a bug that `DataFrame.show` incorrectly fetch all data of the dataframe. ## 1.26.0 (2024-12-05) diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index f206b0129b3..5bbcf3e2b44 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -524,7 +524,9 @@ def run_query( if ignore_results: return {"data": None, "sfqid": results_cursor.sfqid} return self._to_data_or_iter( - results_cursor=results_cursor, to_pandas=to_pandas, to_iter=to_iter + results_cursor=results_cursor, + to_pandas=to_pandas, + to_iter=to_iter, ) else: return AsyncJob( @@ -766,10 +768,13 @@ def get_result_set( return result, result_meta def get_result_and_metadata( - self, plan: SnowflakePlan, **kwargs + self, plan: SnowflakePlan, limit: Optional[int] = None, **kwargs ) -> Tuple[List[Row], List[Attribute]]: - result_set, result_meta = self.get_result_set(plan, **kwargs) - result = result_set_to_rows(result_set["data"]) + if limit is not None: + result_set, result_meta = self.get_result_set(plan, to_iter=True, **kwargs) + else: + result_set, result_meta = self.get_result_set(plan, **kwargs) + result = result_set_to_rows(result_set["data"], limit=limit) attributes = convert_result_meta_to_attribute(result_meta, self.max_string_size) return result, attributes diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index bd27e46f2fd..156de1d000c 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -702,6 +702,7 @@ def result_set_to_rows( result_set: List[Any], result_meta: Optional[Union[List[ResultMetadata], List["ResultMetadataV2"]]] = None, case_sensitive: bool = True, + limit: Optional[int] = None, ) -> List[Row]: col_names = [col.name for col in result_meta] if result_meta else None rows = [] @@ -710,7 +711,9 @@ def result_set_to_rows( row_struct = ( Row._builder.build(*col_names).set_case_sensitive(case_sensitive).to_row() ) - for data in result_set: + for i, data in enumerate(result_set): + if limit is not None and i >= limit: + break if data is None: raise ValueError("Result returned from Python connector is None") row = row_struct(*data) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index c4dd09095fd..7bd7df2445f 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -4368,7 +4368,7 @@ def _show_string( ) else: res, meta = self._session._conn.get_result_and_metadata( - self._plan, **kwargs + self._plan, limit=n, **kwargs ) result = res[:n] diff --git a/src/snowflake/snowpark/mock/_connection.py b/src/snowflake/snowpark/mock/_connection.py index c4ef34bc27b..8bff029b014 100644 --- a/src/snowflake/snowpark/mock/_connection.py +++ b/src/snowflake/snowpark/mock/_connection.py @@ -755,7 +755,7 @@ def get_result_set( ) def get_result_and_metadata( - self, plan: SnowflakePlan, **kwargs + self, plan: SnowflakePlan, limit: Optional[int] = None, **kwargs ) -> Tuple[List[Row], List[Attribute]]: res = execute_mock_plan(plan, plan.expr_to_alias) attrs = [ @@ -772,6 +772,8 @@ def get_result_and_metadata( rows = [] for i in range(len(res)): + if limit is not None and i >= limit: + break values = [] for j, attr in enumerate(attrs): value = res.iloc[i, j] diff --git a/tests/integ/scala/test_dataframe_suite.py b/tests/integ/scala/test_dataframe_suite.py index 33390128924..8e0bb75a302 100644 --- a/tests/integ/scala/test_dataframe_suite.py +++ b/tests/integ/scala/test_dataframe_suite.py @@ -10,6 +10,7 @@ from decimal import Decimal from logging import getLogger from typing import Iterator +from unittest import mock import pytest @@ -300,6 +301,50 @@ def test_show(session): ) +def test_show_non_select_statement(session): + df = session.create_dataframe([[1, 2, 3, 4] for _ in range(100)]).to_df( + ['"col1"', "col2_a", "col2_b", "col3"] + ) + with mock.patch( + "snowflake.snowpark.dataframe.is_sql_select_statement", return_value=False + ): + res = df._show_string(5, _emit_ast=session.ast_enabled) + assert ( + res + == """ +----------------------------------------- +|"col1" |"COL2_A" |"COL2_B" |"COL3" | +----------------------------------------- +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +-----------------------------------------\n""".lstrip() + ) + + # test show with default value + res = df._show_string(_emit_ast=session.ast_enabled) + assert ( + res + == """ +----------------------------------------- +|"col1" |"COL2_A" |"COL2_B" |"COL3" | +----------------------------------------- +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +|1 |2 |3 |4 | +-----------------------------------------\n""".lstrip() + ) + + def test_cache_result(session): table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) session.create_dataframe([[1], [2]], schema=["num"]).write.save_as_table(table_name)