From 958ebf4a6eb7156c97c5b38ad16861a5d88947b1 Mon Sep 17 00:00:00 2001 From: Jonathan Shi <149419494+sfc-gh-joshi@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:10:26 -0700 Subject: [PATCH] SNOW-1615695: Push empty all/any processing down to query compiler (#2039) Co-authored-by: Varnika Budati --- src/snowflake/snowpark/modin/pandas/base.py | 8 +++---- .../plugin/_internal/aggregation_utils.py | 21 +++++++++++++------ .../compiler/snowflake_query_compiler.py | 9 ++++---- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/modin/pandas/base.py b/src/snowflake/snowpark/modin/pandas/base.py index 7be5ff60d91..43ed5ee389f 100644 --- a/src/snowflake/snowpark/modin/pandas/base.py +++ b/src/snowflake/snowpark/modin/pandas/base.py @@ -933,7 +933,7 @@ def all(self, axis=0, bool_only=None, skipna=True, **kwargs): return data_for_compute.all( axis=axis, bool_only=False, skipna=skipna, **kwargs ) - result = self._reduce_dimension( + return self._reduce_dimension( self._query_compiler.all( axis=axis, bool_only=bool_only, skipna=skipna, **kwargs ) @@ -956,7 +956,7 @@ def all(self, axis=0, bool_only=None, skipna=True, **kwargs): return result.all( axis=axis, bool_only=bool_only, skipna=skipna, **kwargs ) - return True if result is None else result + return result def any(self, axis=0, bool_only=None, skipna=True, **kwargs): """ @@ -977,7 +977,7 @@ def any(self, axis=0, bool_only=None, skipna=True, **kwargs): return data_for_compute.any( axis=axis, bool_only=False, skipna=skipna, **kwargs ) - result = self._reduce_dimension( + return self._reduce_dimension( self._query_compiler.any( axis=axis, bool_only=bool_only, skipna=skipna, **kwargs ) @@ -998,7 +998,7 @@ def any(self, axis=0, bool_only=None, skipna=True, **kwargs): return result.any( axis=axis, bool_only=bool_only, skipna=skipna, **kwargs ) - return False if result is None else result + return result def apply( self, diff --git 
a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 7e2693e743e..9b88b286d40 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -478,6 +478,21 @@ def get_snowflake_agg_func( # through the aggregate frontend in this manner is unsupported. return None return lambda col: column_quantile(col, interpolation, q) + elif agg_func in ("all", "any"): + # If there are no rows in the input frame, the function will also return NULL, which should + # instead be TRUE for "all" and FALSE for "any". + # Need to wrap column name in IDENTIFIER, or else the agg function will treat the name + # as a string literal. + # The generated SQL expression for "all" is + # IFNULL(BOOLAND_AGG(IDENTIFIER("column_name")), TRUE) + # The expression for "any" is + # IFNULL(BOOLOR_AGG(IDENTIFIER("column_name")), FALSE) + default_value = bool(agg_func == "all") + return lambda col: builtin("ifnull")( + # mypy refuses to acknowledge snowflake_agg_func is non-NULL here + snowflake_agg_func(builtin("identifier")(col)), # type: ignore[misc] + pandas_lit(default_value), + ) else: snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) @@ -699,12 +714,6 @@ def generate_aggregation_column( agg_snowpark_column = coalesce( snowflake_agg_func(snowpark_column), pandas_lit(0) ) - elif snowflake_agg_func in ( - SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["all"], - SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["any"], - ): - # Need to wrap column name in IDENTIFIER, or else bool agg function will treat the name as a string literal - agg_snowpark_column = snowflake_agg_func(builtin("identifier")(snowpark_column)) elif snowflake_agg_func == array_agg: # Array aggregation requires the ordering columns, which we have to # pass in here. 
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index da641b8df24..22ac143dfe2 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -2130,22 +2130,21 @@ def _bool_reduce_helper( ) else: assert axis == 0 - # The query compiler agg method complains if the resulting aggregation is empty, so we add a special check here + # The query compiler agg method complains if the resulting aggregation is empty, so we add a special check here. if empty_columns: # The result should be an empty series of dtype bool, which is internally represented as an # empty dataframe with only the MODIN_UNNAMED_SERIES_LABEL column return SnowflakeQueryCompiler.from_pandas( native_pd.DataFrame({MODIN_UNNAMED_SERIES_LABEL: []}, dtype=bool) ) - - # The resulting DF is transposed so will have string 'NULL' as a column name, - # so we need to manually remove it + # If there are no rows (but there are columns), booland_agg/boolor_agg would return NULL. + # This behavior is handled within aggregation_utils to avoid an extra query. return self.agg( agg_func, axis=0, args=[], kwargs={"skipna": skipna}, - ).set_columns([MODIN_UNNAMED_SERIES_LABEL]) + ) def all( self,