Skip to content

Commit

Permalink
SNOW-1615695: Push empty all/any processing down to query compiler (#…
Browse files Browse the repository at this point in the history
…2039)

Co-authored-by: Varnika Budati <[email protected]>
  • Loading branch information
sfc-gh-joshi and sfc-gh-vbudati authored Aug 9, 2024
1 parent 3459a4f commit 958ebf4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
8 changes: 4 additions & 4 deletions src/snowflake/snowpark/modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ def all(self, axis=0, bool_only=None, skipna=True, **kwargs):
return data_for_compute.all(
axis=axis, bool_only=False, skipna=skipna, **kwargs
)
result = self._reduce_dimension(
return self._reduce_dimension(
self._query_compiler.all(
axis=axis, bool_only=bool_only, skipna=skipna, **kwargs
)
Expand All @@ -956,7 +956,7 @@ def all(self, axis=0, bool_only=None, skipna=True, **kwargs):
return result.all(
axis=axis, bool_only=bool_only, skipna=skipna, **kwargs
)
return True if result is None else result
return result

def any(self, axis=0, bool_only=None, skipna=True, **kwargs):
"""
Expand All @@ -977,7 +977,7 @@ def any(self, axis=0, bool_only=None, skipna=True, **kwargs):
return data_for_compute.any(
axis=axis, bool_only=False, skipna=skipna, **kwargs
)
result = self._reduce_dimension(
return self._reduce_dimension(
self._query_compiler.any(
axis=axis, bool_only=bool_only, skipna=skipna, **kwargs
)
Expand All @@ -998,7 +998,7 @@ def any(self, axis=0, bool_only=None, skipna=True, **kwargs):
return result.any(
axis=axis, bool_only=bool_only, skipna=skipna, **kwargs
)
return False if result is None else result
return result

def apply(
self,
Expand Down
21 changes: 15 additions & 6 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,21 @@ def get_snowflake_agg_func(
# through the aggregate frontend in this manner is unsupported.
return None
return lambda col: column_quantile(col, interpolation, q)
elif agg_func in ("all", "any"):
# If there are no rows in the input frame, the function will also return NULL, which should
# instead be TRUE for "all" and FALSE for "any".
# Need to wrap column name in IDENTIFIER, or else the agg function will treat the name
# as a string literal.
# The generated SQL expression for "all" is
# IFNULL(BOOLAND_AGG(IDENTIFIER("column_name")), TRUE)
# The expression for "any" is
# IFNULL(BOOLOR_AGG(IDENTIFIER("column_name")), FALSE)
default_value = bool(agg_func == "all")
return lambda col: builtin("ifnull")(
# mypy refuses to acknowledge snowflake_agg_func is non-NULL here
snowflake_agg_func(builtin("identifier")(col)), # type: ignore[misc]
pandas_lit(default_value),
)
else:
snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func)

Expand Down Expand Up @@ -699,12 +714,6 @@ def generate_aggregation_column(
agg_snowpark_column = coalesce(
snowflake_agg_func(snowpark_column), pandas_lit(0)
)
elif snowflake_agg_func in (
SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["all"],
SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["any"],
):
# Need to wrap column name in IDENTIFIER, or else bool agg function will treat the name as a string literal
agg_snowpark_column = snowflake_agg_func(builtin("identifier")(snowpark_column))
elif snowflake_agg_func == array_agg:
# Array aggregation requires the ordering columns, which we have to
# pass in here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2130,22 +2130,21 @@ def _bool_reduce_helper(
)
else:
assert axis == 0
# The query compiler agg method complains if the resulting aggregation is empty, so we add a special check here
# The query compiler agg method complains if the resulting aggregation is empty, so we add a special check here.
if empty_columns:
# The result should be an empty series of dtype bool, which is internally represented as an
# empty dataframe with only the MODIN_UNNAMED_SERIES_LABEL column
return SnowflakeQueryCompiler.from_pandas(
native_pd.DataFrame({MODIN_UNNAMED_SERIES_LABEL: []}, dtype=bool)
)

# The resulting DF is transposed so will have string 'NULL' as a column name,
# so we need to manually remove it
# If there are no rows (but there are columns), booland_agg/boolor_agg would return NULL.
# This behavior is handled within aggregation_utils to avoid an extra query.
return self.agg(
agg_func,
axis=0,
args=[],
kwargs={"skipna": skipna},
).set_columns([MODIN_UNNAMED_SERIES_LABEL])
)

def all(
self,
Expand Down

0 comments on commit 958ebf4

Please sign in to comment.