diff --git a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py index 8ba63bc1ad4..50aee2a3360 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py @@ -420,10 +420,21 @@ def ensure_row_count_column(self) -> "OrderedDataFrame": wrap_double_underscore=True, )[0] ) - ordered_dataframe = self.select( - *self.projected_column_snowflake_quoted_identifiers, - count("*").over().as_(row_count_snowflake_quoted_identifier), - ) + if not self.is_projection_of_table(): + ordered_dataframe = self.select( + *self.projected_column_snowflake_quoted_identifiers, + count("*").over().as_(row_count_snowflake_quoted_identifier), + ) + else: + from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit + + row_count = self.select(count("*").as_("__count_of_rows__"),).collect()[ + 0 + ][0] + ordered_dataframe = self.select( + *self.projected_column_snowflake_quoted_identifiers, + pandas_lit(row_count).as_(row_count_snowflake_quoted_identifier), + ) # inplace update so dataframe_ref can be shared. Note that we keep # the original ordering columns. ordered_dataframe.row_count_snowflake_quoted_identifier = ( @@ -2019,3 +2030,23 @@ def sample(self, n: Optional[int], frac: Optional[float]) -> "OrderedDataFrame": ordering_columns=self.ordering_columns, ) ) + + def is_projection_of_table(self) -> bool: + """ + Return whether or not the current OrderedDataFrame is simply a projection of a table. + + Returns: + bool + True if the current OrderedDataFrame is simply a projection of a table. False if it represents + a more complex operation. + """ + # If we have only performed projections since creating this DataFrame, it will only contain + # 1 API call in the plan - either `Session.sql` for DataFrames based off of I/O operations + # e.g. `read_snowflake` or `read_csv`, or `Session.create_dataframe` for DataFrames created + # out of Python objects. + snowpark_df = self._dataframe_ref.snowpark_dataframe + snowpark_plan = snowpark_df._plan + return len(snowpark_plan.api_calls) == 1 and any( + accepted_api in snowpark_plan.api_calls[0]["name"] + for accepted_api in ["Session.sql", "Session.create_dataframe"] + )