From 28669989c04617e5eeff3b6fc2c63a8016cf8c92 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 13:29:07 -0700 Subject: [PATCH 01/22] SNOW-1659512: Fix literal complexity calculation (#2265) --- .../snowpark/_internal/analyzer/analyzer.py | 5 +---- .../_internal/analyzer/snowflake_plan_node.py | 19 ++++++++++++++++++- tests/integ/test_query_plan_analysis.py | 18 ++++++++++++++++++ 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 76e91b7da92..d8622299ea9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -956,10 +956,7 @@ def do_resolve_with_resolved_children( schema_query = schema_query_for_values_statement(logical_plan.output) if logical_plan.data: - if ( - len(logical_plan.output) * len(logical_plan.data) - < ARRAY_BIND_THRESHOLD - ): + if not logical_plan.is_large_local_data: return self.plan_builder.query( values_statement(logical_plan.output, logical_plan.data), logical_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index e3e032cd94b..aa8730dcf7f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -144,10 +144,27 @@ def __init__( self.data = data self.schema_query = schema_query + @property + def is_large_local_data(self) -> bool: + from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD + + return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self.is_large_local_data: + # When the number of literals exceeds the threshold, we generate 3 queries: + # 1. create table query + # 2. insert into table query + # 3. select * from table query + # We only consider the complexity from the final select * query since other queries + # are built based on it. + return { + PlanNodeCategory.COLUMN: 1, + } + + # If we stay under the threshold, we generate a single query: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) - # TODO: use ARRAY_BIND_THRESHOLD return { PlanNodeCategory.COLUMN: len(self.output), PlanNodeCategory.LITERAL: len(self.data) * len(self.output), diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 0e8bb0d902d..81b852c46c1 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -98,6 +98,24 @@ def test_range_statement(session: Session): ) +def test_literal_complexity_for_snowflake_values(session: Session): + from snowflake.snowpark._internal.analyzer import analyzer + + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + assert_df_subtree_query_complexity( + df1, {PlanNodeCategory.COLUMN: 4, PlanNodeCategory.LITERAL: 4} + ) + + try: + original_threshold = analyzer.ARRAY_BIND_THRESHOLD + analyzer.ARRAY_BIND_THRESHOLD = 2 + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + # SELECT "A", "B" from (SELECT * FROM TEMP_TABLE) + assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN: 3}) + finally: + analyzer.ARRAY_BIND_THRESHOLD = original_threshold + + def test_generator_table_function(session: Session): df1 = session.generator( seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150 From f2f4b3366b1f8319556354b6e008044511ed0961 Mon Sep 17 00:00:00 2001 From: Mahesh Vashishtha Date: Thu, 12 Sep 2024 15:14:28 -0700 Subject: [PATCH 02/22] SNOW-1654730: Refactor aggregation_utils to reduce duplication and clarify interfaces. (#2270) Fixes SNOW-1654730 # Changes to the interface of aggregation_utils - `get_snowflake_agg_func` now returns a `NamedTuple`, `SnowflakeAggFunc`, which contains a snowpark aggregation method, along with the bool `preserves_snowpark_pandas_types`. Formerly, `get_snowflake_agg_func` would return a Snowpark aggregation and the caller would then use separate dictionaries to deduce how the aggregation affects Snowpark pandas types. - Prepend an underscore to the names of several objects in aggregation_utils that are only used internally, e.g. `_array_agg_keepna` and `_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION` - Formerly, to get axis=1 aggregation, query compiler methods would call `generate_rowwise_aggregation_function`. Instead, make `get_snowflake_agg_func` the common interface to get `SnowflakeAggFunc` even for axis=1, and make `generate_rowwise_aggregation_function` an internal method called `_generate_rowwise_aggregation_function`. # Changes to the internals of aggregation_utils - Before this commit, there were 5 different maps describing how to translate pandas aggregations to Snowpark: `SNOWFLAKE_BUILTIN_AGG_FUNC_MAP` mapped from Snowpark pandas aggregations to the axis=0 aggregation; `GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE` and `GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES` told whether the axis=0 aggregations would preserve Snowpark pandas types; `SNOWFLAKE_COLUMNS_AGG_FUNC_MAP` would tell how to aggregate on axis=1 when skipna=True; and `SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP` would tell how to aggregate on axis=1 when skipna=False. All of these maps repeated the mapping of pairs like `"sum"` and `np.sum` to the same aggregation function. In this commit, keep a single mapping, `_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION`, from pandas aggregations to instances of the internal tuple `_SnowparkPandasAggregation`. `_SnowparkPandasAggregation` includes `preserves_snowpark_pandas_type`, as well as optionally the aggregation functions for `axis=0`; `axis=1, skipna=False`; and `axis=1, skipna=True`. # New feature - As a consequence of the refactoring, `groupby().var()` no longer raises `NotImplementedError`, but it's invalid in pandas, so we correctly raise `TypeError`. --------- Signed-off-by: sfc-gh-mvashishtha --- .../plugin/_internal/aggregation_utils.py | 458 ++++++++++++------ .../modin/plugin/_internal/pivot_utils.py | 19 +- .../compiler/snowflake_query_compiler.py | 61 +-- .../plugin/extensions/timedelta_index.py | 10 +- .../modin/groupby/test_groupby_basic_agg.py | 21 + .../modin/groupby/test_groupby_negative.py | 25 +- tests/unit/modin/test_aggregation_utils.py | 51 +- 7 files changed, 422 insertions(+), 223 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 01ccad8f430..3d25b1273b5 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -9,7 +9,7 @@ from collections.abc import Hashable, Iterable from functools import partial from inspect import getmembers -from types import BuiltinFunctionType +from types import BuiltinFunctionType, MappingProxyType from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Union import numpy as np @@ -65,6 +65,9 @@ OrderedDataFrame, OrderingColumn, ) +from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + TimedeltaType, +) from snowflake.snowpark.modin.plugin._internal.utils import ( from_pandas_label, pandas_lit, @@ -85,7 +88,7 @@ } -def array_agg_keepna( +def _array_agg_keepna( column_to_aggregate: ColumnOrName, ordering_columns: Iterable[OrderingColumn] ) -> Column: """ @@ -239,62 +242,63 @@ def _columns_coalescing_idxmax_idxmin_helper( ) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding snowflake builtin aggregation function for axis=0. If any change -# is made to this map, ensure GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE and -# GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES are updated accordingly. -SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": count, - "mean": mean, - "min": min_, - "max": max_, - "idxmax": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmax" - ), - "idxmin": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmin" - ), - "sum": sum_, - "median": median, - "skew": skew, - "std": stddev, - "var": variance, - "all": builtin("booland_agg"), - "any": builtin("boolor_agg"), - np.max: max_, - np.min: min_, - np.sum: sum_, - np.mean: mean, - np.median: median, - np.std: stddev, - np.var: variance, - "array_agg": array_agg, - "quantile": column_quantile, - "nunique": count_distinct, -} -GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = ( - "min", - "max", - "sum", - "mean", - "median", - "std", - np.max, - np.min, - np.sum, - np.mean, - np.median, - np.std, -) -GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = ( - "any", - "all", - "count", - "idxmax", - "idxmin", - "size", - "nunique", -) +class _SnowparkPandasAggregation(NamedTuple): + """ + A representation of a Snowpark pandas aggregation. + + This structure gives us a common representation for an aggregation that may + have multiple aliases, like "sum" and np.sum. + """ + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. + # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool + + # This callable takes a single Snowpark column as input and aggregates the + # column on axis=0. If None, Snowpark pandas does not support this + # aggregation on axis=0. + axis_0_aggregation: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # the columns on axis=1 with skipna=True, i.e. not including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=True. + axis_1_aggregation_skipna: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # the columns on axis=1 with skipna=False, i.e. including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=False. + axis_1_aggregation_keepna: Optional[Callable] = None + + +class SnowflakeAggFunc(NamedTuple): + """ + A Snowflake aggregation, including information about how the aggregation acts on SnowparkPandasType. + """ + + # The aggregation function in Snowpark. + # For aggregation on axis=0, this field should take a single Snowpark + # column and return the aggregated column. + # For aggregation on axis=1, this field should take an arbitrary number + # of Snowpark columns and return the aggregated column. + snowpark_aggregation: Callable + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. + # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool class AggFuncWithLabel(NamedTuple): @@ -413,35 +417,143 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable: return sum(builtin("zeroifnull")(col) for col in cols) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding aggregation function for axis=1 when skipna=True. The returned aggregation -# function may either be a builtin aggregation function, or a function taking in *arg columns -# that then calls the appropriate builtin aggregations. -SNOWFLAKE_COLUMNS_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": _columns_count, - "sum": _columns_coalescing_sum, - np.sum: _columns_coalescing_sum, - "min": _columns_coalescing_min, - "max": _columns_coalescing_max, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - np.min: _columns_coalescing_min, - np.max: _columns_coalescing_max, -} +def _create_pandas_to_snowpark_pandas_aggregation_map( + pandas_functions: Iterable[AggFuncTypeBase], + snowpark_pandas_aggregation: _SnowparkPandasAggregation, +) -> MappingProxyType[AggFuncTypeBase, _SnowparkPandasAggregation]: + """ + Create a map from the given pandas functions to the given _SnowparkPandasAggregation. -# These functions are called instead if skipna=False -SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "min": least, - "max": greatest, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark - # sum_, since Snowpark sum_ gets the sum of all rows within a single column. - "sum": lambda *cols: sum(cols), - np.sum: lambda *cols: sum(cols), - np.min: least, - np.max: greatest, -} + Args; + pandas_functions: The pandas functions that map to the given aggregation. + snowpark_pandas_aggregation: The aggregation to map to + + Returns: + The map. + """ + return MappingProxyType({k: snowpark_pandas_aggregation for k in pandas_functions}) + + +# Map between the pandas input aggregation function (str or numpy function) and +# _SnowparkPandasAggregation representing information about applying the +# aggregation in Snowpark pandas. +_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[ + AggFuncTypeBase, _SnowparkPandasAggregation +] = MappingProxyType( + { + "count": _SnowparkPandasAggregation( + axis_0_aggregation=count, + axis_1_aggregation_skipna=_columns_count, + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("mean", np.mean), + _SnowparkPandasAggregation( + axis_0_aggregation=mean, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("min", np.min), + _SnowparkPandasAggregation( + axis_0_aggregation=min_, + axis_1_aggregation_keepna=least, + axis_1_aggregation_skipna=_columns_coalescing_min, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("max", np.max), + _SnowparkPandasAggregation( + axis_0_aggregation=max_, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=_columns_coalescing_max, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("sum", np.sum), + _SnowparkPandasAggregation( + axis_0_aggregation=sum_, + # IMPORTANT: count and sum use python builtin sum to invoke + # __add__ on each column rather than Snowpark sum_, since + # Snowpark sum_ gets the sum of all rows within a single column. + axis_1_aggregation_keepna=lambda *cols: sum(cols), + axis_1_aggregation_skipna=_columns_coalescing_sum, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("median", np.median), + _SnowparkPandasAggregation( + axis_0_aggregation=median, + preserves_snowpark_pandas_types=True, + ), + ), + "idxmax": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmax" + ), + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "idxmin": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmin" + ), + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "skew": _SnowparkPandasAggregation( + axis_0_aggregation=skew, + preserves_snowpark_pandas_types=True, + ), + "all": _SnowparkPandasAggregation( + # all() for a column with no non-null values is NULL in Snowflake, but True in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("booland_agg")(col(c)), pandas_lit(True) + ), + preserves_snowpark_pandas_types=False, + ), + "any": _SnowparkPandasAggregation( + # any() for a column with no non-null values is NULL in Snowflake, but False in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("boolor_agg")(col(c)), pandas_lit(False) + ), + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("std", np.std), + _SnowparkPandasAggregation( + axis_0_aggregation=stddev, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("var", np.var), + _SnowparkPandasAggregation( + axis_0_aggregation=variance, + # variance units are the square of the input column units, so + # variance does not preserve types. + preserves_snowpark_pandas_types=False, + ), + ), + "array_agg": _SnowparkPandasAggregation( + axis_0_aggregation=array_agg, + preserves_snowpark_pandas_types=False, + ), + "quantile": _SnowparkPandasAggregation( + axis_0_aggregation=column_quantile, + preserves_snowpark_pandas_types=True, + ), + "nunique": _SnowparkPandasAggregation( + axis_0_aggregation=count_distinct, + preserves_snowpark_pandas_types=False, + ), + } +) class AggregateColumnOpParameters(NamedTuple): @@ -462,7 +574,7 @@ class AggregateColumnOpParameters(NamedTuple): agg_snowflake_quoted_identifier: str # the snowflake aggregation function to apply on the column - snowflake_agg_func: Callable + snowflake_agg_func: SnowflakeAggFunc # the columns specifying the order of rows in the column. This is only # relevant for aggregations that depend on row order, e.g. summing a string @@ -471,88 +583,108 @@ class AggregateColumnOpParameters(NamedTuple): def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool: - return agg_func in SNOWFLAKE_BUILTIN_AGG_FUNC_MAP + return agg_func in _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION def get_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int = 0 -) -> Optional[Callable]: + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] +) -> Optional[SnowflakeAggFunc]: """ Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function. If no corresponding snowflake aggregation function can be found, return None. """ - if axis == 0: - snowflake_agg_func = SNOWFLAKE_BUILTIN_AGG_FUNC_MAP.get(agg_func) - if snowflake_agg_func == stddev or snowflake_agg_func == variance: - # for aggregation function std and var, we only support ddof = 0 or ddof = 1. - # when ddof is 1, std is mapped to stddev, var is mapped to variance - # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop - # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 - ddof = agg_kwargs.get("ddof", 1) - if ddof != 1 and ddof != 0: - return None - if ddof == 0: - return stddev_pop if snowflake_agg_func == stddev else var_pop - elif snowflake_agg_func == column_quantile: - interpolation = agg_kwargs.get("interpolation", "linear") - q = agg_kwargs.get("q", 0.5) - if interpolation not in ("linear", "nearest"): - return None - if not is_scalar(q): - # SNOW-1062878 Because list-like q would return multiple rows, calling quantile - # through the aggregate frontend in this manner is unsupported. - return None - return lambda col: column_quantile(col, interpolation, q) - elif agg_func in ("all", "any"): - # If there are no rows in the input frame, the function will also return NULL, which should - # instead by TRUE for "all" and FALSE for "any". - # Need to wrap column name in IDENTIFIER, or else the agg function will treat the name - # as a string literal. - # The generated SQL expression for "all" is - # IFNULL(BOOLAND_AGG(IDENTIFIER("column_name")), TRUE) - # The expression for "any" is - # IFNULL(BOOLOR_AGG(IDENTIFIER("column_name")), FALSE) - default_value = bool(agg_func == "all") - return lambda col: builtin("ifnull")( - # mypy refuses to acknowledge snowflake_agg_func is non-NULL here - snowflake_agg_func(builtin("identifier")(col)), # type: ignore[misc] - pandas_lit(default_value), + if axis == 1: + return _generate_rowwise_aggregation_function(agg_func, agg_kwargs) + + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + + if snowpark_pandas_aggregation is None: + # We don't have any implementation at all for this aggregation. + return None + + snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation + + if snowpark_aggregation is None: + # We don't have an implementation on axis=0 for this aggregation. + return None + + # Rewrite some aggregations according to `agg_kwargs.` + if snowpark_aggregation == stddev or snowpark_aggregation == variance: + # for aggregation function std and var, we only support ddof = 0 or ddof = 1. + # when ddof is 1, std is mapped to stddev, var is mapped to variance + # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop + # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 + ddof = agg_kwargs.get("ddof", 1) + if ddof != 1 and ddof != 0: + return None + if ddof == 0: + snowpark_aggregation = ( + stddev_pop if snowpark_aggregation == stddev else var_pop ) - else: - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) + elif snowpark_aggregation == column_quantile: + interpolation = agg_kwargs.get("interpolation", "linear") + q = agg_kwargs.get("q", 0.5) + if interpolation not in ("linear", "nearest"): + return None + if not is_scalar(q): + # SNOW-1062878 Because list-like q would return multiple rows, calling quantile + # through the aggregate frontend in this manner is unsupported. + return None + + def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: + return column_quantile(col, interpolation, q) - return snowflake_agg_func + assert ( + snowpark_aggregation is not None + ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation." + return SnowflakeAggFunc( + snowpark_aggregation=snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def generate_rowwise_aggregation_function( +def _generate_rowwise_aggregation_function( agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any] -) -> Optional[Callable]: +) -> Optional[SnowflakeAggFunc]: """ Get a callable taking *arg columns to apply for an aggregation. Unlike get_snowflake_agg_func, this function may return a wrapped composition of Snowflake builtin functions depending on the values of the specified kwargs. """ - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) - if not agg_kwargs.get("skipna", True): - snowflake_agg_func = SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP.get( - agg_func, snowflake_agg_func - ) + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + if snowpark_pandas_aggregation is None: + return None + snowpark_aggregation = ( + snowpark_pandas_aggregation.axis_1_aggregation_skipna + if agg_kwargs.get("skipna", True) + else snowpark_pandas_aggregation.axis_1_aggregation_keepna + ) + if snowpark_aggregation is None: + return None min_count = agg_kwargs.get("min_count", 0) if min_count > 0: + original_aggregation = snowpark_aggregation + # Create a case statement to check if the number of non-null values exceeds min_count # when min_count > 0, if the number of not NULL values is < min_count, return NULL. - def agg_func_wrapper(fn: Callable) -> Callable: - return lambda *cols: when( - _columns_count(*cols) < min_count, pandas_lit(None) - ).otherwise(fn(*cols)) + def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn: + return when(_columns_count(*cols) < min_count, pandas_lit(None)).otherwise( + original_aggregation(*cols) + ) - return snowflake_agg_func and agg_func_wrapper(snowflake_agg_func) - return snowflake_agg_func + return SnowflakeAggFunc( + snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def is_supported_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int +def _is_supported_snowflake_agg_func( + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ check if the aggregation function is supported with snowflake. Current supported @@ -570,8 +702,8 @@ def is_supported_snowflake_agg_func( return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None -def are_all_agg_funcs_supported_by_snowflake( - agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: int +def _are_all_agg_funcs_supported_by_snowflake( + agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ Check if all aggregation functions in the given list are snowflake supported @@ -582,14 +714,14 @@ def are_all_agg_funcs_supported_by_snowflake( return False. """ return all( - is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs + _is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs ) def check_is_aggregation_supported_in_snowflake( agg_func: AggFuncType, agg_kwargs: dict[str, Any], - axis: int, + axis: Literal[0, 1], ) -> bool: """ check if distributed implementation with snowflake is available for the aggregation @@ -608,18 +740,18 @@ def check_is_aggregation_supported_in_snowflake( if is_dict_like(agg_func): return all( ( - are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) + _are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) if is_list_like(value) and not is_named_tuple(value) - else is_supported_snowflake_agg_func(value, agg_kwargs, axis) + else _is_supported_snowflake_agg_func(value, agg_kwargs, axis) ) for value in agg_func.values() ) elif is_list_like(agg_func): - return are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) - return is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) + return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) + return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) -def is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: +def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: """ Is the given snowflake aggregation function needs to be applied on the numeric column. """ @@ -697,7 +829,7 @@ def drop_non_numeric_data_columns( ) -def generate_aggregation_column( +def _generate_aggregation_column( agg_column_op_params: AggregateColumnOpParameters, agg_kwargs: dict[str, Any], is_groupby_agg: bool, @@ -721,8 +853,14 @@ def generate_aggregation_column( SnowparkColumn after the aggregation function. The column is also aliased back to the original name """ snowpark_column = agg_column_op_params.snowflake_quoted_identifier - snowflake_agg_func = agg_column_op_params.snowflake_agg_func - if is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( + snowflake_agg_func = agg_column_op_params.snowflake_agg_func.snowpark_aggregation + + if snowflake_agg_func in (variance, var_pop) and isinstance( + agg_column_op_params.data_type, TimedeltaType + ): + raise TypeError("timedelta64 type does not support var operations") + + if _is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( agg_column_op_params.data_type, BooleanType ): # if the column is a boolean column and the aggregation function requires numeric values, @@ -753,7 +891,7 @@ def generate_aggregation_column( # note that we always assume keepna for array_agg. TODO(SNOW-1040398): # make keepna treatment consistent across array_agg and other # aggregation methods. - agg_snowpark_column = array_agg_keepna( + agg_snowpark_column = _array_agg_keepna( snowpark_column, ordering_columns=agg_column_op_params.ordering_columns ) elif ( @@ -857,7 +995,7 @@ def aggregate_with_ordered_dataframe( is_groupby_agg = groupby_columns is not None agg_list: list[SnowparkColumn] = [ - generate_aggregation_column( + _generate_aggregation_column( agg_column_op_params=agg_col_op, agg_kwargs=agg_kwargs, is_groupby_agg=is_groupby_agg, @@ -973,7 +1111,7 @@ def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str: ) -def generate_pandas_labels_for_agg_result_columns( +def _generate_pandas_labels_for_agg_result_columns( pandas_label: Hashable, num_levels: int, agg_func_list: list[AggFuncInfo], @@ -1102,7 +1240,7 @@ def generate_column_agg_info( ) # generate the pandas label and quoted identifier for the result aggregation columns, one # for each aggregation function to apply. - agg_col_labels = generate_pandas_labels_for_agg_result_columns( + agg_col_labels = _generate_pandas_labels_for_agg_result_columns( pandas_label_to_identifier.pandas_label, num_levels, agg_func_list, # type: ignore[arg-type] diff --git a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py index 3bf1062107e..e7a96b49ef1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py @@ -520,12 +520,15 @@ def single_pivot_helper( data_column_snowflake_quoted_identifiers: new data column snowflake quoted identifiers this pivot result data_column_pandas_labels: new data column pandas labels for this pivot result """ - snowpark_aggr_func = get_snowflake_agg_func(pandas_aggr_func_name, {}) - if not is_supported_snowflake_pivot_agg_func(snowpark_aggr_func): + snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0) + if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func( + snowflake_agg_func.snowpark_aggregation + ): # TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations raise ErrorMessage.not_implemented( f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments." ) + snowpark_aggr_func = snowflake_agg_func.snowpark_aggregation pandas_aggr_label, aggr_snowflake_quoted_identifier = value_label_to_identifier_pair @@ -1231,17 +1234,19 @@ def get_margin_aggregation( Returns: Snowpark column expression for the aggregation function result. """ - resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}) + resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}, axis=0) # This would have been resolved during the original pivot at an early stage. assert resolved_aggfunc is not None, "resolved_aggfunc is None" - aggfunc_expr = resolved_aggfunc(snowflake_quoted_identifier) + aggregation_expression = resolved_aggfunc.snowpark_aggregation( + snowflake_quoted_identifier + ) - if resolved_aggfunc == sum_: - aggfunc_expr = coalesce(aggfunc_expr, pandas_lit(0)) + if resolved_aggfunc.snowpark_aggregation == sum_: + aggregation_expression = coalesce(aggregation_expression, pandas_lit(0)) - return aggfunc_expr + return aggregation_expression def expand_pivot_result_with_pivot_table_margins_no_groupby_columns( diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index b5022bff46b..2f6ff69be6c 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -149,8 +149,6 @@ ) from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( AGG_NAME_COL_LABEL, - GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE, - GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES, AggFuncInfo, AggFuncWithLabel, AggregateColumnOpParameters, @@ -161,7 +159,6 @@ convert_agg_func_arg_to_col_agg_func_map, drop_non_numeric_data_columns, generate_column_agg_info, - generate_rowwise_aggregation_function, get_agg_func_to_col_map, get_pandas_aggr_func_name, get_snowflake_agg_func, @@ -3556,41 +3553,23 @@ def convert_func_to_agg_func_info( agg_col_ops, new_data_column_index_names = generate_column_agg_info( internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby ) - # Get the column aggregation functions used to check if the function - # preserves Snowpark pandas types. - agg_col_funcs = [] - for _, func in column_to_agg_func.items(): - if is_list_like(func) and not is_named_tuple(func): - for fn in func: - agg_col_funcs.append(fn.func) - else: - agg_col_funcs.append(func.func) + # the pandas label and quoted identifier generated for each result column # after aggregation will be used as new pandas label and quoted identifiers. new_data_column_pandas_labels = [] new_data_column_quoted_identifiers = [] new_data_column_snowpark_pandas_types = [] - for i in range(len(agg_col_ops)): - col_agg_op = agg_col_ops[i] - col_agg_func = agg_col_funcs[i] - new_data_column_pandas_labels.append(col_agg_op.agg_pandas_label) + for agg_col_op in agg_col_ops: + new_data_column_pandas_labels.append(agg_col_op.agg_pandas_label) new_data_column_quoted_identifiers.append( - col_agg_op.agg_snowflake_quoted_identifier + agg_col_op.agg_snowflake_quoted_identifier + ) + new_data_column_snowpark_pandas_types.append( + agg_col_op.data_type + if isinstance(agg_col_op.data_type, SnowparkPandasType) + and agg_col_op.snowflake_agg_func.preserves_snowpark_pandas_types + else None ) - if col_agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE: - new_data_column_snowpark_pandas_types.append( - col_agg_op.data_type - if isinstance(col_agg_op.data_type, SnowparkPandasType) - else None - ) - elif col_agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES: - # In the case where the aggregation overrides the type of the output data column - # (e.g. any always returns boolean data columns), set the output Snowpark pandas type - # of the given column to None - new_data_column_snowpark_pandas_types.append(None) # type: ignore - else: - self._raise_not_implemented_error_for_timedelta() - new_data_column_snowpark_pandas_types = None # type: ignore # The ordering of the named aggregations is changed by us when we process # the agg_kwargs into the func dict (named aggregations on the same @@ -3644,7 +3623,7 @@ def convert_func_to_agg_func_info( ), agg_pandas_label=None, agg_snowflake_quoted_identifier=row_position_quoted_identifier, - snowflake_agg_func=min_, + snowflake_agg_func=get_snowflake_agg_func("min", agg_kwargs={}, axis=0), ordering_columns=internal_frame.ordering_columns, ) agg_col_ops.append(row_position_agg_column_op) @@ -5761,9 +5740,9 @@ def agg( pandas_column_labels=frame.data_column_pandas_labels, ) if agg_arg in ("idxmin", "idxmax") - else generate_rowwise_aggregation_function(agg_arg, kwargs)( - *(col(c) for c in data_col_identifiers) - ) + else get_snowflake_agg_func( + agg_arg, kwargs, axis=1 + ).snowpark_aggregation(*(col(c) for c in data_col_identifiers)) for agg_arg in agg_args } pandas_labels = list(agg_col_map.keys()) @@ -13613,6 +13592,16 @@ def _window_agg( } ).frame else: + snowflake_agg_func = get_snowflake_agg_func(agg_func, agg_kwargs, axis=0) + if snowflake_agg_func is None: + # We don't have test coverage for this situation because we + # test individual rolling and expanding methods we've implemented, + # like rolling_sum(), but other rolling methods raise + # NotImplementedError immediately. We also don't support rolling + # agg(), which might take us here. + ErrorMessage.not_implemented( # pragma: no cover + f"Window aggregation does not support the aggregation {repr_aggregate_function(agg_func, agg_kwargs)}" + ) new_frame = frame.update_snowflake_quoted_identifiers_with_expressions( { # If aggregation is count use count on row_position_quoted_identifier @@ -13623,7 +13612,7 @@ def _window_agg( if agg_func == "count" else count(col(quoted_identifier)).over(window_expr) >= min_periods, - get_snowflake_agg_func(agg_func, agg_kwargs)( + snowflake_agg_func.snowpark_aggregation( # Expanding is cumulative so replace NULL with 0 for sum aggregation builtin("zeroifnull")(col(quoted_identifier)) if window_func == WindowFunction.EXPANDING diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 1dbb743aa32..87a4de75c1d 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -36,6 +36,7 @@ from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( AggregateColumnOpParameters, + SnowflakeAggFunc, aggregate_with_ordered_dataframe, ) from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( @@ -440,7 +441,14 @@ def mean( pandas_labels=["mean"] )[0] agg_column_op_params = AggregateColumnOpParameters( - index_id, LongType(), "mean", new_index_id, fn.mean, [] + index_id, + LongType(), + "mean", + new_index_id, + snowflake_agg_func=SnowflakeAggFunc( + preserves_snowpark_pandas_types=True, snowpark_aggregation=fn.mean + ), + ordering_columns=[], ) mean_value = aggregate_with_ordered_dataframe( frame.ordered_dataframe, [agg_column_op_params], {"skipna": skipna} diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index 09acd49bb21..d136551dafe 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -1124,6 +1124,27 @@ def test_timedelta(agg_func, by): ) +@sql_count_checker(query_count=1) +def test_groupby_timedelta_var(): + """ + Test that we can group by a timedelta column and take var() of an integer column. + + Note that we can't take the groupby().var() of the timedelta column because + var() is not defined for timedelta, in pandas or in Snowpark pandas. + """ + eval_snowpark_pandas_result( + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: df.groupby("A").var(), + ) + + def test_timedelta_groupby_agg(): native_df = native_pd.DataFrame( { diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py index a009e1089b0..0c9c056c2a7 100644 --- a/tests/integ/modin/groupby/test_groupby_negative.py +++ b/tests/integ/modin/groupby/test_groupby_negative.py @@ -18,6 +18,7 @@ MAP_DATA_AND_TYPE, MIXED_NUMERIC_STR_DATA_AND_TYPE, TIMESTAMP_DATA_AND_TYPE, + create_test_dfs, eval_snowpark_pandas_result, ) @@ -559,20 +560,12 @@ def test_groupby_agg_invalid_min_count( @sql_count_checker(query_count=0) -def test_groupby_var_no_support_for_timedelta(): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ) - snow_df = pd.DataFrame(native_df) - with pytest.raises( - NotImplementedError, - match=re.escape( - "SnowflakeQueryCompiler::groupby_agg is not yet implemented for Timedelta Type" +def test_timedelta_var_invalid(): + eval_snowpark_pandas_result( + *create_test_dfs( + [["key0", pd.Timedelta(1)]], ), - ): - snow_df.groupby("B").var() + lambda df: df.groupby(0).var(), + expect_exception=True, + expect_exception_type=TypeError, + ) diff --git a/tests/unit/modin/test_aggregation_utils.py b/tests/unit/modin/test_aggregation_utils.py index 5434387ba71..6c9edfd024f 100644 --- a/tests/unit/modin/test_aggregation_utils.py +++ b/tests/unit/modin/test_aggregation_utils.py @@ -2,12 +2,20 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from types import MappingProxyType +from unittest import mock + import numpy as np import pytest +import snowflake.snowpark.modin.plugin._internal.aggregation_utils as aggregation_utils +from snowflake.snowpark.functions import greatest, sum as sum_ from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + SnowflakeAggFunc, + _is_supported_snowflake_agg_func, + _SnowparkPandasAggregation, check_is_aggregation_supported_in_snowflake, - is_supported_snowflake_agg_func, + get_snowflake_agg_func, ) @@ -53,8 +61,8 @@ ("quantile", {}, 1, False), ], ) -def test_is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None: - assert is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid +def test__is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None: + assert _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid @pytest.mark.parametrize( @@ -103,3 +111,40 @@ def test_check_aggregation_snowflake_execution_capability_by_args( agg_func=agg_func, agg_kwargs=agg_kwargs, axis=0 ) assert can_be_distributed == expected_result + + +@pytest.mark.parametrize( + "agg_func, agg_kwargs, axis, expected", + [ + (np.sum, {}, 0, SnowflakeAggFunc(sum_, True)), + ("max", {"skipna": False}, 1, SnowflakeAggFunc(greatest, True)), + ("test", {}, 0, None), + ], +) +def test_get_snowflake_agg_func(agg_func, agg_kwargs, axis, expected): + result = get_snowflake_agg_func(agg_func, agg_kwargs, axis) + if expected is None: + assert result is None + else: + assert result == expected + + +def test_get_snowflake_agg_func_with_no_implementation_on_axis_0(): + """Test get_snowflake_agg_func for a function that we support on axis=1 but not on axis=0.""" + # We have to patch the internal dictionary + # _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION here because there is + # no real function that we support on axis=1 but not on axis=0. + with mock.patch.object( + aggregation_utils, + "_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION", + MappingProxyType( + { + "max": _SnowparkPandasAggregation( + preserves_snowpark_pandas_types=True, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=greatest, + ) + } + ), + ): + assert get_snowflake_agg_func(agg_func="max", agg_kwargs={}, axis=0) is None From f566e25f5950757b071d644e3272de8b7b40066d Mon Sep 17 00:00:00 2001 From: Jianzhun Du <68252326+sfc-gh-jdu@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:38:56 -0700 Subject: [PATCH 03/22] SNOW-1641644: Drop temp table directly at garbage collection instead of using multi-threading (#2214) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1641644 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. see details in [discussion](https://snowflake.slack.com/archives/C04HF38JFAQ/p1725057972543059?thread_ts=1724782971.351959&cid=C04HF38JFAQ) --- src/snowflake/snowpark/_internal/telemetry.py | 52 ++++++++ .../_internal/temp_table_auto_cleaner.py | 89 ++++++------- src/snowflake/snowpark/session.py | 39 ++++-- tests/integ/test_telemetry.py | 48 +++++++ tests/integ/test_temp_table_cleanup.py | 117 ++++++++++++------ tests/unit/test_session.py | 1 + 6 files changed, 247 insertions(+), 99 deletions(-) diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 8b9ef2acccb..aef60828334 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -79,6 +79,20 @@ class TelemetryField(Enum): QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" QUERY_PLAN_COMPLEXITY = "query_plan_complexity" + # temp table cleanup + TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup" + NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned" + NUM_TEMP_TABLES_CREATED = "num_temp_tables_created" + TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled" + TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = ( + "snowpark_temp_table_cleanup_abnormal_exception" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = ( + "temp_table_cleanup_abnormal_exception_table_name" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = ( + "temp_table_cleanup_abnormal_exception_message" + ) # These DataFrame APIs call other DataFrame APIs @@ -464,3 +478,41 @@ def send_large_query_optimization_skipped_telemetry( }, } self.send(message) + + def send_temp_table_cleanup_telemetry( + self, + session_id: str, + temp_table_cleaner_enabled: bool, + num_temp_tables_cleaned: int, + num_temp_tables_created: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANER_ENABLED.value: temp_table_cleaner_enabled, + TelemetryField.NUM_TEMP_TABLES_CLEANED.value: num_temp_tables_cleaned, + TelemetryField.NUM_TEMP_TABLES_CREATED.value: num_temp_tables_created, + }, + } + self.send(message) + + def send_temp_table_cleanup_abnormal_exception_telemetry( + self, + session_id: str, + table_name: str, + exception_message: str, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME.value: table_name, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE.value: exception_message, + }, + } + self.send(message) diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py index b9055c6fc58..4fa17498d34 100644 --- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py +++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py @@ -4,9 +4,7 @@ import logging import weakref from collections import defaultdict -from queue import Empty, Queue -from threading import Event, Thread -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable @@ -33,12 +31,6 @@ def __init__(self, session: "Session") -> None: # to its reference count for later temp table management # this dict will still be maintained even if the cleaner is stopped (`stop()` is called) self.ref_count_map: Dict[str, int] = defaultdict(int) - # unused temp table will be put into the queue for cleanup - self.queue: Queue = Queue() - # thread for removing temp tables (running DROP TABLE sql) - self.cleanup_thread: Optional[Thread] = None - # An event managing a flag that indicates whether the cleaner is started - self.stop_event = Event() def add(self, table: SnowflakeTable) -> None: self.ref_count_map[table.name] += 1 @@ -46,61 +38,60 @@ def add(self, table: SnowflakeTable) -> None: # and this table will be dropped finally _ = weakref.finalize(table, self._delete_ref_count, table.name) - def _delete_ref_count(self, name: str) -> None: + def _delete_ref_count(self, name: str) -> None: # pragma: no cover """ Decrements the reference count of a temporary table, and if the count reaches zero, puts this table in the queue for cleanup. """ self.ref_count_map[name] -= 1 if self.ref_count_map[name] == 0: - self.ref_count_map.pop(name) - # clean up - self.queue.put(name) + if self.session.auto_clean_up_temp_table_enabled: + self.drop_table(name) elif self.ref_count_map[name] < 0: logging.debug( f"Unexpected reference count {self.ref_count_map[name]} for table {name}" ) - def process_cleanup(self) -> None: - while not self.stop_event.is_set(): - try: - # it's non-blocking after timeout and become interruptable with stop_event - # it will raise an `Empty` exception if queue is empty after timeout, - # then we catch this exception and avoid breaking loop - table_name = self.queue.get(timeout=1) - self.drop_table(table_name) - except Empty: - continue - - def drop_table(self, name: str) -> None: + def drop_table(self, name: str) -> None: # pragma: no cover common_log_text = f"temp table {name} in session {self.session.session_id}" - logging.debug(f"Cleanup Thread: Ready to drop {common_log_text}") + logging.debug(f"Ready to drop {common_log_text}") + query_id = None try: - # TODO SNOW-1556553: Remove this workaround once multi-threading of Snowpark session is supported - with self.session._conn._conn.cursor() as cursor: - cursor.execute( - f"drop table if exists {name} /* internal query to drop unused temp table */", - _statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name}, + async_job = self.session.sql( + f"drop table if exists {name} /* internal query to drop unused temp table */", + )._internal_collect_with_tag_no_telemetry( + block=False, statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name} + ) + query_id = async_job.query_id + logging.debug(f"Dropping {common_log_text} with query id {query_id}") + except Exception as ex: # pragma: no cover + warning_message = f"Failed to drop {common_log_text}, exception: {ex}" + logging.warning(warning_message) + if query_id is None: + # If no query_id is available, it means the query haven't been accepted by gs, + # and it won't occur in our job_etl_view, send a separate telemetry for recording. + self.session._conn._telemetry_client.send_temp_table_cleanup_abnormal_exception_telemetry( + self.session.session_id, + name, + str(ex), ) - logging.debug(f"Cleanup Thread: Successfully dropped {common_log_text}") - except Exception as ex: - logging.warning( - f"Cleanup Thread: Failed to drop {common_log_text}, exception: {ex}" - ) # pragma: no cover - - def is_alive(self) -> bool: - return self.cleanup_thread is not None and self.cleanup_thread.is_alive() - - def start(self) -> None: - self.stop_event.clear() - if not self.is_alive(): - self.cleanup_thread = Thread(target=self.process_cleanup) - self.cleanup_thread.start() def stop(self) -> None: """ - The cleaner will stop immediately and leave unfinished temp tables in the queue. + Stops the cleaner (no-op) and sends the telemetry. """ - self.stop_event.set() - if self.is_alive(): - self.cleanup_thread.join() + self.session._conn._telemetry_client.send_temp_table_cleanup_telemetry( + self.session.session_id, + temp_table_cleaner_enabled=self.session.auto_clean_up_temp_table_enabled, + num_temp_tables_cleaned=self.num_temp_tables_cleaned, + num_temp_tables_created=self.num_temp_tables_created, + ) + + @property + def num_temp_tables_created(self) -> int: + return len(self.ref_count_map) + + @property + def num_temp_tables_cleaned(self) -> int: + # TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled + return sum(v == 0 for v in self.ref_count_map.values()) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 8da0794f139..8ffd4081473 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -582,9 +582,6 @@ def __init__( self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) - if self._auto_clean_up_temp_table_enabled: - self._temp_table_auto_cleaner.start() - _logger.info("Snowpark Session information: %s", self._session_info) def __enter__(self): @@ -623,8 +620,8 @@ def close(self) -> None: raise SnowparkClientExceptionMessages.SERVER_FAILED_CLOSE_SESSION(str(ex)) finally: try: - self._conn.close() self._temp_table_auto_cleaner.stop() + self._conn.close() _logger.info("Closed session: %s", self._session_id) finally: _remove_session(self) @@ -658,10 +655,33 @@ def auto_clean_up_temp_table_enabled(self) -> bool: :meth:`DataFrame.cache_result` in the current session when the DataFrame is no longer referenced (i.e., gets garbage collected). The default value is ``False``. + Example:: + + >>> import gc + >>> + >>> def f(session: Session) -> str: + ... df = session.create_dataframe( + ... [[1, 2], [3, 4]], schema=["a", "b"] + ... ).cache_result() + ... return df.table_name + ... + >>> session.auto_clean_up_temp_table_enabled = True + >>> table_name = f(session) + >>> assert table_name + >>> gc.collect() # doctest: +SKIP + >>> + >>> # The temporary table created by cache_result will be dropped when the DataFrame is no longer referenced + >>> # outside the function + >>> session.sql(f"show tables like '{table_name}'").count() + 0 + + >>> session.auto_clean_up_temp_table_enabled = False + Note: - Even if this parameter is ``False``, Snowpark still records temporary tables when - their corresponding DataFrame are garbage collected. Therefore, if you turn it on in the middle of your session or after turning it off, - the target temporary tables will still be cleaned up accordingly. + Temporary tables will only be dropped if this parameter is enabled during garbage collection. + If a temporary table is no longer referenced when the parameter is on, it will be dropped during garbage collection. + However, if garbage collection occurs while the parameter is off, the table will not be removed. + Note that Python's garbage collection is triggered opportunistically, with no guaranteed timing. """ return self._auto_clean_up_temp_table_enabled @@ -755,11 +775,6 @@ def auto_clean_up_temp_table_enabled(self, value: bool) -> None: self._session_id, value ) self._auto_clean_up_temp_table_enabled = value - is_alive = self._temp_table_auto_cleaner.is_alive() - if value and not is_alive: - self._temp_table_auto_cleaner.start() - elif not value and is_alive: - self._temp_table_auto_cleaner.stop() else: raise ValueError( "value for auto_clean_up_temp_table_enabled must be True or False!" diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index 7aaa5c9e5dd..39749de76f6 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -1223,3 +1223,51 @@ def send_telemetry(): data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) assert data == expected_data assert type_ == "snowpark_compilation_stage_statistics" + + +def test_temp_table_cleanup(session): + client = session._conn._telemetry_client + + def send_telemetry(): + client.send_temp_table_cleanup_telemetry( + session.session_id, + temp_table_cleaner_enabled=True, + num_temp_tables_cleaned=2, + num_temp_tables_created=5, + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "temp_table_cleaner_enabled": True, + "num_temp_tables_cleaned": 2, + "num_temp_tables_created": 5, + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_temp_table_cleanup" + + +def test_temp_table_cleanup_exception(session): + client = session._conn._telemetry_client + + def send_telemetry(): + client.send_temp_table_cleanup_abnormal_exception_telemetry( + session.session_id, + table_name="table_name_placeholder", + exception_message="exception_message_placeholder", + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "temp_table_cleanup_abnormal_exception_table_name": "table_name_placeholder", + "temp_table_cleanup_abnormal_exception_message": "exception_message_placeholder", + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_temp_table_cleanup_abnormal_exception" diff --git a/tests/integ/test_temp_table_cleanup.py b/tests/integ/test_temp_table_cleanup.py index 4ac87661484..cdd97d49937 100644 --- a/tests/integ/test_temp_table_cleanup.py +++ b/tests/integ/test_temp_table_cleanup.py @@ -12,6 +12,7 @@ from snowflake.snowpark._internal.utils import ( TempObjectType, random_name_for_temp_object, + warning_dict, ) from snowflake.snowpark.functions import col from tests.utils import IS_IN_STORED_PROC @@ -25,40 +26,61 @@ WAIT_TIME = 1 +@pytest.fixture(autouse=True) +def setup(session): + auto_clean_up_temp_table_enabled = session.auto_clean_up_temp_table_enabled + session.auto_clean_up_temp_table_enabled = True + yield + session.auto_clean_up_temp_table_enabled = auto_clean_up_temp_table_enabled + + def test_basic(session): + session._temp_table_auto_cleaner.ref_count_map.clear() df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() table_name = df1.table_name table_ids = table_name.split(".") df1.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = df1.select("*").filter(col("a") == 1) df2.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df3 = df1.union_all(df2) df3.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 del df2 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 del df3 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids) - assert table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 def test_function(session): + session._temp_table_auto_cleaner.ref_count_map.clear() table_name = None def f(session: Session) -> None: @@ -68,13 +90,16 @@ def f(session: Session) -> None: nonlocal table_name table_name = df.table_name assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() f(session) gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_name.split(".")) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 @pytest.mark.parametrize( @@ -86,33 +111,42 @@ def f(session: Session) -> None: ], ) def test_copy(session, copy_function): + session._temp_table_auto_cleaner.ref_count_map.clear() df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() table_name = df1.table_name table_ids = table_name.split(".") df1.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = copy_function(df1).select("*").filter(col("a") == 1) df2.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 2 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df2 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids) - assert table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_reference_count_map_multiple_sessions(db_parameters, session): + session._temp_table_auto_cleaner.ref_count_map.clear() new_session = Session.builder.configs(db_parameters).create() + new_session.auto_clean_up_temp_table_enabled = True try: df1 = session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] @@ -120,43 +154,59 @@ def test_reference_count_map_multiple_sessions(db_parameters, session): table_name1 = df1.table_name table_ids1 = table_name1.split(".") assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 1 - assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 + assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 0 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = new_session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] ).cache_result() table_name2 = df2.table_name table_ids2 = table_name2.split(".") - assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids1) assert new_session._table_exists(table_ids2) assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 - assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 + assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - new_session._temp_table_auto_cleaner.start() del df2 gc.collect() time.sleep(WAIT_TIME) assert not new_session._table_exists(table_ids2) - assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 finally: new_session.close() def test_save_as_table_no_drop(session): - session._temp_table_auto_cleaner.start() + session._temp_table_auto_cleaner.ref_count_map.clear() def f(session: Session, temp_table_name: str) -> None: session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] ).write.save_as_table(temp_table_name, table_type="temp") - assert session._temp_table_auto_cleaner.ref_count_map[temp_table_name] == 0 + assert temp_table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) f(session, temp_table_name) @@ -165,34 +215,25 @@ def f(session: Session, temp_table_name: str) -> None: assert session._table_exists([temp_table_name]) -def test_start_stop(session): - session._temp_table_auto_cleaner.stop() - - df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() - table_name = df1.table_name +def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog): + warning_dict.clear() + with caplog.at_level(logging.WARNING): + session.auto_clean_up_temp_table_enabled = False + assert session.auto_clean_up_temp_table_enabled is False + assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() + table_name = df.table_name table_ids = table_name.split(".") - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 - del df1 + del df gc.collect() - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 - assert not session._temp_table_auto_cleaner.queue.empty() - assert session._table_exists(table_ids) - - session._temp_table_auto_cleaner.start() time.sleep(WAIT_TIME) - assert session._temp_table_auto_cleaner.queue.empty() - assert not session._table_exists(table_ids) - - -def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog): - with caplog.at_level(logging.WARNING): - session.auto_clean_up_temp_table_enabled = True + assert session._table_exists(table_ids) + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 + session.auto_clean_up_temp_table_enabled = True assert session.auto_clean_up_temp_table_enabled is True - assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text - assert session._temp_table_auto_cleaner.is_alive() - session.auto_clean_up_temp_table_enabled = False - assert session.auto_clean_up_temp_table_enabled is False - assert not session._temp_table_auto_cleaner.is_alive() + with pytest.raises( ValueError, match="value for auto_clean_up_temp_table_enabled must be True or False!", diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 262c9e82c44..370ee455d62 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -112,6 +112,7 @@ def test_used_scoped_temp_object(): def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() + fake_connection._telemetry_client = mock.Mock() fake_connection.is_closed = MagicMock(return_value=False) exception_msg = "Mock exception for session.cancel_all" fake_connection.run_query = MagicMock(side_effect=Exception(exception_msg)) From c7be18c88838743145aa0e1ade59711ab5564d3a Mon Sep 17 00:00:00 2001 From: Mahesh Vashishtha Date: Fri, 13 Sep 2024 13:55:38 -0700 Subject: [PATCH 04/22] SNOW-1653121: Support some Timedelta aggregations on axis=0. (#2248) Fixes SNOW-1653121 Test and support aggregation on axis=0. We still raise `NotImplementedError` if: 1) the aggregation requires concatenating a frame with timedelta types 2) the aggregation requires transposing a row containing a timedelta type and other types. This change also fixes the bug that timedelta aggregations like mean would produce the wrong type (and the wrong result) by truncating the float result if `preserves_snowpark_pandas_type`. --------- Signed-off-by: sfc-gh-mvashishtha --- CHANGELOG.md | 1 + .../plugin/_internal/aggregation_utils.py | 16 ++ .../compiler/snowflake_query_compiler.py | 29 ++-- .../plugin/extensions/timedelta_index.py | 44 +++--- tests/integ/modin/conftest.py | 27 ++++ tests/integ/modin/frame/test_aggregate.py | 102 +++++++++++++ tests/integ/modin/frame/test_describe.py | 15 ++ tests/integ/modin/frame/test_idxmax_idxmin.py | 14 +- tests/integ/modin/frame/test_nunique.py | 11 +- tests/integ/modin/frame/test_skew.py | 32 +++- tests/integ/modin/groupby/test_all_any.py | 30 +++- .../modin/groupby/test_groupby_basic_agg.py | 142 +++++++++--------- .../modin/groupby/test_groupby_first_last.py | 11 ++ tests/integ/modin/groupby/test_quantile.py | 8 + tests/integ/modin/index/conftest.py | 1 + tests/integ/modin/index/test_all_any.py | 3 + tests/integ/modin/index/test_argmax_argmin.py | 12 ++ tests/integ/modin/series/test_aggregate.py | 67 +++++++++ .../integ/modin/series/test_argmax_argmin.py | 5 + tests/integ/modin/series/test_describe.py | 16 ++ .../series/test_first_last_valid_index.py | 4 + .../integ/modin/series/test_idxmax_idxmin.py | 5 + tests/integ/modin/series/test_nunique.py | 15 ++ 23 files changed, 487 insertions(+), 123 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0589d4a358..b19e909930e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ #### New Features - Added support for `TimedeltaIndex.mean` method. +- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. ## 1.22.1 (2024-09-11) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 3d25b1273b5..0005df924db 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -56,6 +56,7 @@ stddev, stddev_pop, sum as sum_, + trunc, var_pop, variance, when, @@ -698,6 +699,8 @@ def _is_supported_snowflake_agg_func( is_valid: bool. Whether it is valid to implement with snowflake or not. """ if isinstance(agg_func, tuple) and len(agg_func) == 2: + # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`, + # take the second part of the named aggregation. agg_func = agg_func[0] return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None @@ -963,6 +966,19 @@ def _generate_aggregation_column( ), f"No case expression is constructed with skipna({skipna}), min_count({min_count})" agg_snowpark_column = case_expr.otherwise(agg_snowpark_column) + if ( + isinstance(agg_column_op_params.data_type, TimedeltaType) + and agg_column_op_params.snowflake_agg_func.preserves_snowpark_pandas_types + ): + # timedelta aggregations that produce timedelta results might produce + # a decimal type in snowflake, e.g. + # pd.Series([pd.Timestamp(1), pd.Timestamp(2)]).mean() produces 1.5 in + # Snowflake. We truncate the decimal part of the result, as pandas + # does. + agg_snowpark_column = cast( + trunc(agg_snowpark_column), agg_column_op_params.data_type.snowpark_type + ) + # rename the column to agg_column_quoted_identifier agg_snowpark_column = agg_snowpark_column.as_( agg_column_op_params.agg_snowflake_quoted_identifier diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 2f6ff69be6c..38402e5b984 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -3553,7 +3553,6 @@ def convert_func_to_agg_func_info( agg_col_ops, new_data_column_index_names = generate_column_agg_info( internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby ) - # the pandas label and quoted identifier generated for each result column # after aggregation will be used as new pandas label and quoted identifiers. new_data_column_pandas_labels = [] @@ -3570,7 +3569,6 @@ def convert_func_to_agg_func_info( and agg_col_op.snowflake_agg_func.preserves_snowpark_pandas_types else None ) - # The ordering of the named aggregations is changed by us when we process # the agg_kwargs into the func dict (named aggregations on the same # column are moved to be contiguous, see groupby.py::aggregate for an @@ -5636,8 +5634,6 @@ def agg( args: the arguments passed for the aggregation kwargs: keyword arguments passed for the aggregation function. """ - self._raise_not_implemented_error_for_timedelta() - numeric_only = kwargs.get("numeric_only", False) # Call fallback if the aggregation function passed in the arg is currently not supported # by snowflake engine. @@ -5683,6 +5679,11 @@ def agg( not is_list_like(value) for value in func.values() ) if axis == 1: + if any( + isinstance(t, TimedeltaType) + for t in internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values() + ): + ErrorMessage.not_implemented_for_timedelta("agg(axis=1)") if self.is_multiindex(): # TODO SNOW-1010307 fix axis=1 behavior with MultiIndex ErrorMessage.not_implemented( @@ -5862,7 +5863,13 @@ def generate_agg_qc( index_column_snowflake_quoted_identifiers=[ agg_name_col_quoted_identifier ], - data_column_types=None, + data_column_types=[ + col.data_type + if isinstance(col.data_type, SnowparkPandasType) + and col.snowflake_agg_func.preserves_snowpark_pandas_types + else None + for col in col_agg_infos + ], index_column_types=None, ) return SnowflakeQueryCompiler(single_agg_dataframe) @@ -9108,7 +9115,9 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler": SnowflakeQueryCompiler Transposed new QueryCompiler object. """ - self._raise_not_implemented_error_for_timedelta() + if len(set(self._modin_frame.cached_data_column_snowpark_pandas_types)) > 1: + # In this case, transpose may lose types. + self._raise_not_implemented_error_for_timedelta() frame = self._modin_frame @@ -12492,8 +12501,6 @@ def _quantiles_single_col( column would allow us to create an accurate row position column, but would require a potentially expensive JOIN operator afterwards to apply the correct index labels. """ - self._raise_not_implemented_error_for_timedelta() - assert len(self._modin_frame.data_column_pandas_labels) == 1 if index is not None: @@ -12558,7 +12565,7 @@ def _quantiles_single_col( ], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=[index_identifier], - data_column_types=None, + data_column_types=original_frame.cached_data_column_snowpark_pandas_types, index_column_types=None, ) # We cannot call astype() directly to convert an index column, so we replicate @@ -14566,8 +14573,6 @@ def idxmax( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmax", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -14592,8 +14597,6 @@ def idxmin( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmin", axis=axis, skipna=skipna, numeric_only=numeric_only ) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 87a4de75c1d..9cb4ffa7327 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -32,13 +32,7 @@ from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable from pandas.core.dtypes.common import is_timedelta64_dtype -from snowflake.snowpark import functions as fn from snowflake.snowpark.modin.pandas import DataFrame, Series -from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( - AggregateColumnOpParameters, - SnowflakeAggFunc, - aggregate_with_ordered_dataframe, -) from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, ) @@ -46,7 +40,6 @@ from snowflake.snowpark.modin.plugin.utils.error_message import ( timedelta_index_not_implemented, ) -from snowflake.snowpark.types import LongType _CONSTRUCTOR_DEFAULTS = { "unit": lib.no_default, @@ -434,26 +427,25 @@ def mean( raise ValueError( f"axis should be 0 for TimedeltaIndex.mean, found '{axis}'" ) - # TODO SNOW-1620439: Reuse code from Series.mean. - frame = self._query_compiler._modin_frame - index_id = frame.index_column_snowflake_quoted_identifiers[0] - new_index_id = frame.ordered_dataframe.generate_snowflake_quoted_identifiers( - pandas_labels=["mean"] - )[0] - agg_column_op_params = AggregateColumnOpParameters( - index_id, - LongType(), - "mean", - new_index_id, - snowflake_agg_func=SnowflakeAggFunc( - preserves_snowpark_pandas_types=True, snowpark_aggregation=fn.mean - ), - ordering_columns=[], + pandas_dataframe_result = ( + # reset_index(drop=False) copies the index column of + # self._query_compiler into a new data column. Use `drop=False` + # so that we don't have to use SQL row_number() to generate a new + # index column. + self._query_compiler.reset_index(drop=False) + # Aggregate the data column. + .agg("mean", axis=0, args=(), kwargs={"skipna": skipna}) + # convert the query compiler to a pandas dataframe with + # dimensions 1x1 (note that the frame has a single row even + # if `self` is empty.) + .to_pandas() ) - mean_value = aggregate_with_ordered_dataframe( - frame.ordered_dataframe, [agg_column_op_params], {"skipna": skipna} - ).collect()[0][0] - return native_pd.Timedelta(np.nan if mean_value is None else int(mean_value)) + assert pandas_dataframe_result.shape == ( + 1, + 1, + ), "Internal error: aggregation result is not 1x1." + # Return the only element in the frame. + return pandas_dataframe_result.iloc[0, 0] @timedelta_index_not_implemented() def as_unit(self, unit: str) -> TimedeltaIndex: diff --git a/tests/integ/modin/conftest.py b/tests/integ/modin/conftest.py index 2f24954e769..a7217b38a50 100644 --- a/tests/integ/modin/conftest.py +++ b/tests/integ/modin/conftest.py @@ -715,3 +715,30 @@ def numeric_test_data_4x4(): "C": [7, 10, 13, 16], "D": [8, 11, 14, 17], } + + +@pytest.fixture +def timedelta_native_df() -> pandas.DataFrame: + return pandas.DataFrame( + { + "A": [ + pd.Timedelta(days=1), + pd.Timedelta(days=2), + pd.Timedelta(days=3), + pd.Timedelta(days=4), + ], + "B": [ + pd.Timedelta(minutes=-1), + pd.Timedelta(minutes=0), + pd.Timedelta(minutes=5), + pd.Timedelta(minutes=6), + ], + "C": [ + None, + pd.Timedelta(nanoseconds=5), + pd.Timedelta(nanoseconds=0), + pd.Timedelta(nanoseconds=4), + ], + "D": pandas.to_timedelta([pd.NaT] * 4), + } + ) diff --git a/tests/integ/modin/frame/test_aggregate.py b/tests/integ/modin/frame/test_aggregate.py index b018682b6f8..ba68ae13734 100644 --- a/tests/integ/modin/frame/test_aggregate.py +++ b/tests/integ/modin/frame/test_aggregate.py @@ -187,6 +187,108 @@ def test_string_sum_with_nulls(): assert_series_equal(snow_result.to_pandas(), native_pd.Series(["ab"])) +class TestTimedelta: + """Test aggregating dataframes containing timedelta columns.""" + + @pytest.mark.parametrize( + "func, union_count", + [ + param( + lambda df: df.aggregate(["min"]), + 0, + id="aggregate_list_with_one_element", + ), + param(lambda df: df.aggregate(x=("A", "max")), 0, id="single_named_agg"), + # this works since all results are timedelta and we don't need to do any concats. + param( + lambda df: df.aggregate({"B": "mean", "A": "sum"}), + 0, + id="dict_producing_two_timedeltas", + ), + # this works since even though we need to do concats, all the results are non-timdelta. + param( + lambda df: df.aggregate(x=("B", "all"), y=("B", "any")), + 1, + id="named_agg_producing_two_bools", + ), + # note following aggregation requires transpose + param(lambda df: df.aggregate(max), 0, id="aggregate_max"), + param(lambda df: df.min(), 0, id="min"), + param(lambda df: df.max(), 0, id="max"), + param(lambda df: df.count(), 0, id="count"), + param(lambda df: df.sum(), 0, id="sum"), + param(lambda df: df.mean(), 0, id="mean"), + param(lambda df: df.median(), 0, id="median"), + param(lambda df: df.std(), 0, id="std"), + param(lambda df: df.quantile(), 0, id="single_quantile"), + param(lambda df: df.quantile([0.01, 0.99]), 1, id="two_quantiles"), + ], + ) + def test_supported_axis_0(self, func, union_count, timedelta_native_df): + with SqlCounter(query_count=1, union_count=union_count): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + func, + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1653126") + def test_axis_1(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), lambda df: df.sum(axis=1) + ) + + @sql_count_checker(query_count=0) + def test_var_invalid(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + lambda df: df.var(), + expect_exception=True, + expect_exception_type=TypeError, + assert_exception_equal=False, + expect_exception_match=re.escape( + "timedelta64 type does not support var operations" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", + ) + @pytest.mark.parametrize( + "operation", + [ + lambda df: df.aggregate({"A": ["count", "max"], "B": [max, "min"]}), + lambda df: df.aggregate({"B": ["count"], "A": "sum", "C": ["max", "min"]}), + lambda df: df.aggregate( + x=pd.NamedAgg("A", "max"), y=("B", "min"), c=("A", "count") + ), + lambda df: df.aggregate(["min", np.max]), + lambda df: df.aggregate(x=("A", "max"), y=("C", "min"), z=("A", "min")), + lambda df: df.aggregate(x=("A", "max"), y=pd.NamedAgg("A", "max")), + lambda df: df.aggregate( + {"B": ["idxmax"], "A": "sum", "C": ["max", "idxmin"]} + ), + ], + ) + def test_agg_requires_concat_with_timedelta(self, timedelta_native_df, operation): + eval_snowpark_pandas_result(*create_test_dfs(timedelta_native_df), operation) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires transposing a one-row frame with integer and timedelta.", + ) + def test_agg_produces_timedelta_and_non_timedelta_type(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + lambda df: df.aggregate({"B": "idxmax", "A": "sum"}), + ) + + @pytest.mark.parametrize( "func, expected_union_count", [ diff --git a/tests/integ/modin/frame/test_describe.py b/tests/integ/modin/frame/test_describe.py index a9668c5794f..4f1882d441d 100644 --- a/tests/integ/modin/frame/test_describe.py +++ b/tests/integ/modin/frame/test_describe.py @@ -358,3 +358,18 @@ def test_describe_object_file(resources_path): df = pd.read_csv(test_files.test_concat_file1_csv) native_df = df.to_pandas() eval_snowpark_pandas_result(df, native_df, lambda x: x.describe(include="O")) + + +@sql_count_checker(query_count=0) +@pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", +) +def test_timedelta(timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs( + timedelta_native_df, + ), + lambda df: df.describe(), + ) diff --git a/tests/integ/modin/frame/test_idxmax_idxmin.py b/tests/integ/modin/frame/test_idxmax_idxmin.py index 72fe88968bc..87041060bd2 100644 --- a/tests/integ/modin/frame/test_idxmax_idxmin.py +++ b/tests/integ/modin/frame/test_idxmax_idxmin.py @@ -196,8 +196,18 @@ def test_idxmax_idxmin_with_dates(func, axis): @sql_count_checker(query_count=1) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) -@pytest.mark.parametrize("axis", [0, 1]) -@pytest.mark.xfail(reason="SNOW-1625380 TODO") +@pytest.mark.parametrize( + "axis", + [ + 0, + pytest.param( + 1, + marks=pytest.mark.xfail( + strict=True, raises=NotImplementedError, reason="SNOW-1653126" + ), + ), + ], +) def test_idxmax_idxmin_with_timedelta(func, axis): native_df = native_pd.DataFrame( data={ diff --git a/tests/integ/modin/frame/test_nunique.py b/tests/integ/modin/frame/test_nunique.py index d0cad8ec2ad..78098d34386 100644 --- a/tests/integ/modin/frame/test_nunique.py +++ b/tests/integ/modin/frame/test_nunique.py @@ -11,8 +11,13 @@ from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result -TEST_LABELS = np.array(["A", "B", "C", "D"]) -TEST_DATA = [[0, 1, 2, 3], [0, 0, 0, 0], [None, 0, None, 0], [None, None, None, None]] +TEST_LABELS = np.array(["A", "B", "C", "D", "E"]) +TEST_DATA = [ + [0, 1, 2, 3, pd.Timedelta(4)], + [0, 0, 0, 0, pd.Timedelta(0)], + [None, 0, None, 0, pd.Timedelta(0)], + [None, None, None, None, None], +] # which original dataframe (constructed from slicing) to test for TEST_SLICES = [ @@ -80,7 +85,7 @@ def test_dataframe_nunique_no_columns(native_df): [ pytest.param(None, id="default_columns"), pytest.param( - [["bar", "bar", "baz", "foo"], ["one", "two", "one", "two"]], + [["bar", "bar", "baz", "foo", "foo"], ["one", "two", "one", "two", "one"]], id="2D_columns", ), ], diff --git a/tests/integ/modin/frame/test_skew.py b/tests/integ/modin/frame/test_skew.py index 72fad6cebdc..94b7fd79c24 100644 --- a/tests/integ/modin/frame/test_skew.py +++ b/tests/integ/modin/frame/test_skew.py @@ -8,7 +8,11 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import assert_series_equal +from tests.integ.modin.utils import ( + assert_series_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) @sql_count_checker(query_count=1) @@ -62,16 +66,22 @@ def test_skew_basic(): }, "kwargs": {"numeric_only": True, "skipna": True}, }, + { + "frame": { + "A": [pd.Timedelta(1)], + }, + "kwargs": { + "numeric_only": True, + }, + }, ], ) @sql_count_checker(query_count=1) def test_skew(data): - native_df = native_pd.DataFrame(data["frame"]) - snow_df = pd.DataFrame(native_df) - assert_series_equal( - snow_df.skew(**data["kwargs"]), - native_df.skew(**data["kwargs"]), - rtol=1.0e-5, + eval_snowpark_pandas_result( + *create_test_dfs(data["frame"]), + lambda df: df.skew(**data["kwargs"]), + rtol=1.0e-5 ) @@ -103,6 +113,14 @@ def test_skew(data): }, "kwargs": {"level": 2}, }, + { + "frame": { + "A": [pd.Timedelta(1)], + }, + "kwargs": { + "numeric_only": False, + }, + }, ], ) @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/groupby/test_all_any.py b/tests/integ/modin/groupby/test_all_any.py index d5234dfbdb5..df8df44d47c 100644 --- a/tests/integ/modin/groupby/test_all_any.py +++ b/tests/integ/modin/groupby/test_all_any.py @@ -14,7 +14,11 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + assert_frame_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) @pytest.mark.parametrize( @@ -109,3 +113,27 @@ def test_all_any_chained(): lambda df: df.apply(lambda ser: ser.str.len()) ) ) + + +@sql_count_checker(query_count=1) +def test_timedelta_any_with_nulls(): + """ + Test this case separately because pandas behavior is different from Snowpark pandas behavior. + + pandas bug that does not apply to Snowpark pandas: + https://github.com/pandas-dev/pandas/issues/59712 + """ + snow_df, native_df = create_test_dfs( + { + "key": ["a"], + "A": native_pd.Series([pd.NaT], dtype="timedelta64[ns]"), + }, + ) + assert_frame_equal( + native_df.groupby("key").any(), + native_pd.DataFrame({"A": [True]}, index=native_pd.Index(["a"], name="key")), + ) + assert_frame_equal( + snow_df.groupby("key").any(), + native_pd.DataFrame({"A": [False]}, index=native_pd.Index(["a"], name="key")), + ) diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index d136551dafe..cbf5b75d48c 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -1096,81 +1096,81 @@ def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df): ) -@pytest.mark.parametrize( - "agg_func", - [ - "count", - "sum", - "mean", - "median", - "std", - ], -) -@pytest.mark.parametrize("by", ["A", "B"]) -@sql_count_checker(query_count=1) -def test_timedelta(agg_func, by): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ) - snow_df = pd.DataFrame(native_df) - - eval_snowpark_pandas_result( - snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)() - ) - - -@sql_count_checker(query_count=1) -def test_groupby_timedelta_var(): - """ - Test that we can group by a timedelta column and take var() of an integer column. - - Note that we can't take the groupby().var() of the timedelta column because - var() is not defined for timedelta, in pandas or in Snowpark pandas. - """ - eval_snowpark_pandas_result( - *create_test_dfs( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ), - lambda df: df.groupby("A").var(), - ) - - -def test_timedelta_groupby_agg(): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - "C": [True, False, False, True], - } +class TestTimedelta: + @sql_count_checker(query_count=1) + @pytest.mark.parametrize( + "method", + [ + "count", + "mean", + "min", + "max", + "idxmax", + "idxmin", + "sum", + "median", + "std", + "nunique", + ], ) - snow_df = pd.DataFrame(native_df) - with SqlCounter(query_count=1): + @pytest.mark.parametrize("by", ["A", "B"]) + def test_aggregation_methods(self, method, by): eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: getattr(df.groupby(by), method)(), ) - with SqlCounter(query_count=1): - eval_snowpark_pandas_result( - snow_df, - native_df, + + @sql_count_checker(query_count=1) + @pytest.mark.parametrize( + "operation", + [ + lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), lambda df: df.groupby("B").agg({"A": ["sum", "median"], "C": "min"}), + lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}), + lambda df: df.groupby("B").agg(["mean", "std"]), + lambda df: df.groupby("B").agg({"A": ["count", np.sum]}), + lambda df: df.groupby("B").agg({"A": "sum"}), + ], + ) + def test_agg(self, operation): + eval_snowpark_pandas_result( + *create_test_dfs( + native_pd.DataFrame( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + "C": [True, False, False, True], + } + ) + ), + operation, ) - with SqlCounter(query_count=1): + + @sql_count_checker(query_count=1) + def test_groupby_timedelta_var(self): + """ + Test that we can group by a timedelta column and take var() of an integer column. + + Note that we can't take the groupby().var() of the timedelta column because + var() is not defined for timedelta, in pandas or in Snowpark pandas. + """ eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}), + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: df.groupby("A").var(), ) diff --git a/tests/integ/modin/groupby/test_groupby_first_last.py b/tests/integ/modin/groupby/test_groupby_first_last.py index 5da35806dd1..5e04d5a6fc2 100644 --- a/tests/integ/modin/groupby/test_groupby_first_last.py +++ b/tests/integ/modin/groupby/test_groupby_first_last.py @@ -46,6 +46,17 @@ [np.nan], ] ), + "col11_timedelta": [ + pd.Timedelta("1 days"), + None, + pd.Timedelta("2 days"), + None, + None, + None, + None, + None, + None, + ], } diff --git a/tests/integ/modin/groupby/test_quantile.py b/tests/integ/modin/groupby/test_quantile.py index b14299fee63..940d366a7e2 100644 --- a/tests/integ/modin/groupby/test_quantile.py +++ b/tests/integ/modin/groupby/test_quantile.py @@ -64,6 +64,14 @@ # ), # All NA ([np.nan] * 5, [np.nan] * 5), + pytest.param( + pd.timedelta_range( + "1 days", + "5 days", + ), + pd.timedelta_range("1 second", "5 second"), + id="timedelta", + ), ], ) @pytest.mark.parametrize("q", [0, 0.5, 1]) diff --git a/tests/integ/modin/index/conftest.py b/tests/integ/modin/index/conftest.py index 84454fc4a27..26afd232c4f 100644 --- a/tests/integ/modin/index/conftest.py +++ b/tests/integ/modin/index/conftest.py @@ -79,4 +79,5 @@ tz="America/Los_Angeles", ), native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]), + native_pd.TimedeltaIndex(["4 days", None, "-1 days", "5 days"]), ] diff --git a/tests/integ/modin/index/test_all_any.py b/tests/integ/modin/index/test_all_any.py index 267e7929ea1..499be6f03dc 100644 --- a/tests/integ/modin/index/test_all_any.py +++ b/tests/integ/modin/index/test_all_any.py @@ -25,6 +25,9 @@ native_pd.Index(["a", "b", "c", "d"]), native_pd.Index([5, None, 7]), native_pd.Index([], dtype="object"), + native_pd.Index([pd.Timedelta(0), None]), + native_pd.Index([pd.Timedelta(0)]), + native_pd.Index([pd.Timedelta(0), pd.Timedelta(1)]), ] NATIVE_INDEX_EMPTY_DATA = [ diff --git a/tests/integ/modin/index/test_argmax_argmin.py b/tests/integ/modin/index/test_argmax_argmin.py index 6d446a0a66a..7d42f3b88c9 100644 --- a/tests/integ/modin/index/test_argmax_argmin.py +++ b/tests/integ/modin/index/test_argmax_argmin.py @@ -18,6 +18,18 @@ native_pd.Index([4, None, 1, 3, 4, 1]), native_pd.Index([4, None, 1, 3, 4, 1], name="some name"), native_pd.Index([1, 10, 4, 3, 4]), + pytest.param( + native_pd.Index( + [ + pd.Timedelta(1), + pd.Timedelta(10), + pd.Timedelta(4), + pd.Timedelta(3), + pd.Timedelta(4), + ] + ), + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["argmax", "argmin"]) diff --git a/tests/integ/modin/series/test_aggregate.py b/tests/integ/modin/series/test_aggregate.py index fa354fda1fc..c3e40828d94 100644 --- a/tests/integ/modin/series/test_aggregate.py +++ b/tests/integ/modin/series/test_aggregate.py @@ -1,6 +1,8 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import re + import modin.pandas as pd import numpy as np import pandas as native_pd @@ -17,6 +19,7 @@ MAP_DATA_AND_TYPE, MIXED_NUMERIC_STR_DATA_AND_TYPE, TIMESTAMP_DATA_AND_TYPE, + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_test_series, eval_snowpark_pandas_result, ) @@ -358,3 +361,67 @@ def test_2_tuple_named_agg_errors_for_series(native_series, agg_kwargs): expect_exception_type=SpecificationError, assert_exception_equal=True, ) + + +class TestTimedelta: + """Test aggregating a timedelta series.""" + + @pytest.mark.parametrize( + "func, union_count, is_scalar", + [ + pytest.param(*v, id=str(i)) + for i, v in enumerate( + [ + (lambda series: series.aggregate(["min"]), 0, False), + (lambda series: series.aggregate({"A": "max"}), 0, False), + # this works since even though we need to do concats, all the results are non-timdelta. + (lambda df: df.aggregate(["all", "any", "count"]), 2, False), + # note following aggregation requires transpose + (lambda df: df.aggregate(max), 0, True), + (lambda df: df.min(), 0, True), + (lambda df: df.max(), 0, True), + (lambda df: df.count(), 0, True), + (lambda df: df.sum(), 0, True), + (lambda df: df.mean(), 0, True), + (lambda df: df.median(), 0, True), + (lambda df: df.std(), 0, True), + (lambda df: df.quantile(), 0, True), + (lambda df: df.quantile([0.01, 0.99]), 0, False), + ] + ) + ], + ) + def test_supported(self, func, union_count, timedelta_native_df, is_scalar): + with SqlCounter(query_count=1, union_count=union_count): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + func, + comparator=validate_scalar_result + if is_scalar + else assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + ) + + @sql_count_checker(query_count=0) + def test_var_invalid(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + lambda series: series.var(), + expect_exception=True, + expect_exception_type=TypeError, + assert_exception_equal=False, + expect_exception_match=re.escape( + "timedelta64 type does not support var operations" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", + ) + def test_unsupported_due_to_concat(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + lambda df: df.agg(["count", "max"]), + ) diff --git a/tests/integ/modin/series/test_argmax_argmin.py b/tests/integ/modin/series/test_argmax_argmin.py index 607b36a27f3..e212e3ba2dd 100644 --- a/tests/integ/modin/series/test_argmax_argmin.py +++ b/tests/integ/modin/series/test_argmax_argmin.py @@ -18,6 +18,11 @@ ([4, None, 1, 3, 4, 1], ["A", "B", "C", "D", "E", "F"]), ([4, None, 1, 3, 4, 1], [None, "B", "C", "D", "E", "F"]), ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + pytest.param( + [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)], + ["A", "B", "C", "D", "E"], + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["argmax", "argmin"]) diff --git a/tests/integ/modin/series/test_describe.py b/tests/integ/modin/series/test_describe.py index 9ecd2e33a3d..0f7bbda6c3a 100644 --- a/tests/integ/modin/series/test_describe.py +++ b/tests/integ/modin/series/test_describe.py @@ -11,6 +11,7 @@ from tests.integ.modin.sql_counter import sql_count_checker from tests.integ.modin.utils import ( assert_series_equal, + create_test_dfs, create_test_series, eval_snowpark_pandas_result, ) @@ -156,3 +157,18 @@ def test_describe_multiindex(data, index): eval_snowpark_pandas_result( *create_test_series(data, index=index), lambda ser: ser.describe() ) + + +@sql_count_checker(query_count=0) +@pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", +) +def test_timedelta(timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs( + timedelta_native_df, + ), + lambda df: df["A"].describe(), + ) diff --git a/tests/integ/modin/series/test_first_last_valid_index.py b/tests/integ/modin/series/test_first_last_valid_index.py index 1e8d052e10f..1930bdf1088 100644 --- a/tests/integ/modin/series/test_first_last_valid_index.py +++ b/tests/integ/modin/series/test_first_last_valid_index.py @@ -22,6 +22,10 @@ native_pd.Series([5, 6, 7, 8], index=["i", "am", "iron", "man"]), native_pd.Series([None, None, 2], index=[None, 1, 2]), native_pd.Series([None, None, 2], index=[None, None, None]), + pytest.param( + native_pd.Series([None, None, pd.Timedelta(2)], index=[None, 1, 2]), + id="timedelta", + ), ], ) def test_first_and_last_valid_index_series(native_series): diff --git a/tests/integ/modin/series/test_idxmax_idxmin.py b/tests/integ/modin/series/test_idxmax_idxmin.py index ea536240a42..e8e66a30f61 100644 --- a/tests/integ/modin/series/test_idxmax_idxmin.py +++ b/tests/integ/modin/series/test_idxmax_idxmin.py @@ -17,6 +17,11 @@ ([1, None, 4, 3, 4], ["A", "B", "C", "D", "E"]), ([1, None, 4, 3, 4], [None, "B", "C", "D", "E"]), ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + pytest.param( + [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)], + ["A", "B", "C", "D", "E"], + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) diff --git a/tests/integ/modin/series/test_nunique.py b/tests/integ/modin/series/test_nunique.py index bb20e9e4a53..3856dbc516a 100644 --- a/tests/integ/modin/series/test_nunique.py +++ b/tests/integ/modin/series/test_nunique.py @@ -6,6 +6,7 @@ import numpy as np import pandas as native_pd import pytest +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker @@ -32,6 +33,20 @@ [True, None, False, True, None], [1.1, "a", None] * 4, [native_pd.to_datetime("2023-12-01"), native_pd.to_datetime("1999-09-09")] * 2, + param( + [ + native_pd.Timedelta(1), + native_pd.Timedelta(1), + native_pd.Timedelta(2), + None, + None, + ], + id="timedelta_with_nulls", + ), + param( + [native_pd.Timedelta(1), native_pd.Timedelta(1), native_pd.Timedelta(2)], + id="timedelta_without_nulls", + ), ], ) @pytest.mark.parametrize("dropna", [True, False]) From 5b5c03bebd2c2c5f8a597498df45c5e730b635e3 Mon Sep 17 00:00:00 2001 From: Hazem Elmeleegy Date: Fri, 13 Sep 2024 14:26:08 -0700 Subject: [PATCH 05/22] SNOW-1660952, SNOW-1660954: Add support for DatetimeIndex.tz_localize/tz_convert (#2281) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1660952, SNOW-1660954 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. SNOW-1660952, SNOW-1660954: Add support for DatetimeIndex.tz_localize/tz_convert. --- CHANGELOG.md | 1 + .../supported/datetime_index_supported.rst | 4 +- .../modin/plugin/_internal/timestamp_utils.py | 2 +- .../compiler/snowflake_query_compiler.py | 30 +++-- .../modin/plugin/extensions/datetime_index.py | 29 +++-- .../index/test_datetime_index_methods.py | 111 ++++++++++++++++++ 6 files changed, 157 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b19e909930e..d8efe82596f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -125,6 +125,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det - Added support for `Series.dt.total_seconds` method. - Added support for `DataFrame.apply(axis=0)`. - Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`. +- Added support for `DatetimeIndex.tz_convert` and `DatetimeIndex.tz_localize`. #### Improvements diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst index 68b1935da96..3afe671aee7 100644 --- a/docs/source/modin/supported/datetime_index_supported.rst +++ b/docs/source/modin/supported/datetime_index_supported.rst @@ -82,9 +82,9 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``snap`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``tz_convert`` | N | | | +| ``tz_convert`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``tz_localize`` | N | | | +| ``tz_localize`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``round`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py index f8629e664f3..3b714087535 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py @@ -525,7 +525,7 @@ def tz_convert_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column: The column after conversion to the specified timezone """ if tz is None: - return convert_timezone(pandas_lit("UTC"), column) + return to_timestamp_ntz(convert_timezone(pandas_lit("UTC"), column)) else: if isinstance(tz, dt.tzinfo): tz_name = tz.tzname(None) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 38402e5b984..8ef3bdf9bee 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -16660,6 +16660,7 @@ def dt_tz_localize( tz: Union[str, tzinfo], ambiguous: str = "raise", nonexistent: str = "raise", + include_index: bool = False, ) -> "SnowflakeQueryCompiler": """ Localize tz-naive to tz-aware. @@ -16667,39 +16668,50 @@ def dt_tz_localize( tz : str, pytz.timezone, optional ambiguous : {"raise", "inner", "NaT"} or bool mask, default: "raise" nonexistent : {"raise", "shift_forward", "shift_backward, "NaT"} or pandas.timedelta, default: "raise" + include_index: Whether to include the index columns in the operation. Returns: BaseQueryCompiler New QueryCompiler containing values with localized time zone. """ + dtype = self.index_dtypes[0] if include_index else self.dtypes[0] + if not include_index: + method_name = "Series.dt.tz_localize" + else: + assert is_datetime64_any_dtype(dtype), "column must be datetime" + method_name = "DatetimeIndex.tz_localize" + if not isinstance(ambiguous, str) or ambiguous != "raise": - ErrorMessage.parameter_not_implemented_error( - "ambiguous", "Series.dt.tz_localize" - ) + ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) if not isinstance(nonexistent, str) or nonexistent != "raise": - ErrorMessage.parameter_not_implemented_error( - "nonexistent", "Series.dt.tz_localize" - ) + ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - lambda column: tz_localize_column(column, tz) + lambda column: tz_localize_column(column, tz), + include_index, ) ) - def dt_tz_convert(self, tz: Union[str, tzinfo]) -> "SnowflakeQueryCompiler": + def dt_tz_convert( + self, + tz: Union[str, tzinfo], + include_index: bool = False, + ) -> "SnowflakeQueryCompiler": """ Convert time-series data to the specified time zone. Args: tz : str, pytz.timezone + include_index: Whether to include the index columns in the operation. Returns: A new QueryCompiler containing values with converted time zone. """ return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - lambda column: tz_convert_column(column, tz) + lambda column: tz_convert_column(column, tz), + include_index, ) ) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index df136af1a34..38edb9f7bee 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -960,7 +960,6 @@ def snap(self, freq: Frequency = "S") -> DatetimeIndex: DatetimeIndex(['2023-01-01', '2023-01-01', '2023-02-01', '2023-02-01'], dtype='datetime64[ns]', freq=None) """ - @datetime_index_not_implemented() def tz_convert(self, tz) -> DatetimeIndex: """ Convert tz-aware Datetime Array/Index from one time zone to another. @@ -1025,8 +1024,14 @@ def tz_convert(self, tz) -> DatetimeIndex: '2014-08-01 09:00:00'], dtype='datetime64[ns]', freq='h') """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. + return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_convert( + tz, + include_index=True, + ) + ) - @datetime_index_not_implemented() def tz_localize( self, tz, @@ -1104,21 +1109,29 @@ def tz_localize( Localize DatetimeIndex in US/Eastern time zone: - >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP - >>> tz_aware # doctest: +SKIP - DatetimeIndex(['2018-03-01 09:00:00-05:00', - '2018-03-02 09:00:00-05:00', + >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') + >>> tz_aware + DatetimeIndex(['2018-03-01 09:00:00-05:00', '2018-03-02 09:00:00-05:00', '2018-03-03 09:00:00-05:00'], - dtype='datetime64[ns, US/Eastern]', freq=None) + dtype='datetime64[ns, UTC-05:00]', freq=None) With the ``tz=None``, we can remove the time zone information while keeping the local time (not converted to UTC): - >>> tz_aware.tz_localize(None) # doctest: +SKIP + >>> tz_aware.tz_localize(None) DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', '2018-03-03 09:00:00'], dtype='datetime64[ns]', freq=None) """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. + return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_localize( + tz, + ambiguous, + nonexistent, + include_index=True, + ) + ) def round( self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py index 143e1d74080..793485f97d6 100644 --- a/tests/integ/modin/index/test_datetime_index_methods.py +++ b/tests/integ/modin/index/test_datetime_index_methods.py @@ -7,6 +7,7 @@ import numpy as np import pandas as native_pd import pytest +import pytz import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker @@ -17,6 +18,46 @@ eval_snowpark_pandas_result, ) +timezones = pytest.mark.parametrize( + "tz", + [ + None, + # Use a subset of pytz.common_timezones containing a few timezones in each + *[ + param_for_one_tz + for tz in [ + "Africa/Abidjan", + "Africa/Timbuktu", + "America/Adak", + "America/Yellowknife", + "Antarctica/Casey", + "Asia/Dhaka", + "Asia/Manila", + "Asia/Shanghai", + "Atlantic/Stanley", + "Australia/Sydney", + "Canada/Pacific", + "Europe/Chisinau", + "Europe/Luxembourg", + "Indian/Christmas", + "Pacific/Chatham", + "Pacific/Wake", + "US/Arizona", + "US/Central", + "US/Eastern", + "US/Hawaii", + "US/Mountain", + "US/Pacific", + "UTC", + ] + for param_for_one_tz in ( + pytz.timezone(tz), + tz, + ) + ], + ], +) + @sql_count_checker(query_count=0) def test_datetime_index_construction(): @@ -233,6 +274,76 @@ def test_normalize(): ) +@sql_count_checker(query_count=1, join_count=1) +@timezones +def test_tz_convert(tz): + native_index = native_pd.date_range( + start="2021-01-01", periods=5, freq="7h", tz="US/Eastern" + ) + native_index = native_index.append( + native_pd.DatetimeIndex([pd.NaT], tz="US/Eastern") + ) + snow_index = pd.DatetimeIndex(native_index) + + # Using eval_snowpark_pandas_result() was not possible because currently + # Snowpark pandas DatetimeIndex only mainains a timzeone-naive dtype + # even if the data contains a timezone. + assert snow_index.tz_convert(tz).equals( + pd.DatetimeIndex(native_index.tz_convert(tz)) + ) + + +@sql_count_checker(query_count=1, join_count=1) +@timezones +def test_tz_localize(tz): + native_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + snow_index = pd.DatetimeIndex(native_index) + + # Using eval_snowpark_pandas_result() was not possible because currently + # Snowpark pandas DatetimeIndex only mainains a timzeone-naive dtype + # even if the data contains a timezone. + assert snow_index.tz_localize(tz).equals( + pd.DatetimeIndex(native_index.tz_localize(tz)) + ) + + +@pytest.mark.parametrize( + "ambiguous, nonexistent", + [ + ("infer", "raise"), + ("NaT", "raise"), + (np.array([True, True, False]), "raise"), + ("raise", "shift_forward"), + ("raise", "shift_backward"), + ("raise", "NaT"), + ("raise", pd.Timedelta("1h")), + ("infer", "shift_forward"), + ], +) +@sql_count_checker(query_count=0) +def test_tz_localize_negative(ambiguous, nonexistent): + native_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + snow_index = pd.DatetimeIndex(native_index) + with pytest.raises(NotImplementedError): + snow_index.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent) + + @pytest.mark.parametrize( "datetime_index_value", [ From 08ff293e35ed4054654e84ec3becc0c1e7d403a6 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Sep 2024 15:07:15 -0700 Subject: [PATCH 06/22] SNOW-1636678 add server side param for complexity bounds (#2273) --- .../compiler/large_query_breakdown.py | 19 ++++---- .../_internal/compiler/plan_compiler.py | 7 +-- .../_internal/compiler/telemetry_constants.py | 3 ++ src/snowflake/snowpark/_internal/telemetry.py | 23 ++++++++- src/snowflake/snowpark/session.py | 44 +++++++++++++++++ tests/integ/test_large_query_breakdown.py | 47 ++++++++----------- tests/integ/test_session.py | 26 ++++++++++ 7 files changed, 125 insertions(+), 44 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index 836628345aa..8d16383a4ce 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -58,11 +58,6 @@ ) from snowflake.snowpark.session import Session -# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT -# in Snowflake. This is the limit where we start seeing compilation errors. -COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000 -COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000 - _logger = logging.getLogger(__name__) @@ -123,6 +118,12 @@ def __init__( self._query_generator = query_generator self.logical_plans = logical_plans self._parent_map = defaultdict(set) + self.complexity_score_lower_bound = ( + session.large_query_breakdown_complexity_bounds[0] + ) + self.complexity_score_upper_bound = ( + session.large_query_breakdown_complexity_bounds[1] + ) def apply(self) -> List[LogicalPlan]: if is_active_transaction(self.session): @@ -183,13 +184,13 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]: complexity_score = get_complexity_score(root.cumulative_node_complexity) _logger.debug(f"Complexity score for root {type(root)} is: {complexity_score}") - if complexity_score <= COMPLEXITY_SCORE_UPPER_BOUND: + if complexity_score <= self.complexity_score_upper_bound: # Skip optimization if the complexity score is within the upper bound. return [root] plans = [] # TODO: SNOW-1617634 Have a one pass algorithm to find the valid node for partitioning - while complexity_score > COMPLEXITY_SCORE_UPPER_BOUND: + while complexity_score > self.complexity_score_upper_bound: child = self._find_node_to_breakdown(root) if child is None: _logger.debug( @@ -277,7 +278,9 @@ def _is_node_valid_to_breakdown(self, node: LogicalPlan) -> Tuple[bool, int]: """ score = get_complexity_score(node.cumulative_node_complexity) valid_node = ( - COMPLEXITY_SCORE_LOWER_BOUND < score < COMPLEXITY_SCORE_UPPER_BOUND + self.complexity_score_lower_bound + < score + < self.complexity_score_upper_bound ) and self._is_node_pipeline_breaker(node) if valid_node: _logger.debug( diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 211b66820ec..3e6dba71be4 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -16,8 +16,6 @@ ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.compiler.large_query_breakdown import ( - COMPLEXITY_SCORE_LOWER_BOUND, - COMPLEXITY_SCORE_UPPER_BOUND, LargeQueryBreakdown, ) from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import ( @@ -128,10 +126,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: summary_value = { TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, - CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: ( - COMPLEXITY_SCORE_LOWER_BOUND, - COMPLEXITY_SCORE_UPPER_BOUND, - ), + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index 223b6a1326f..be61a1ac924 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -11,6 +11,9 @@ class CompilationStageTelemetryField(Enum): "snowpark_large_query_breakdown_optimization_skipped" ) TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics" + TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS = ( + "snowpark_large_query_breakdown_update_complexity_bounds" + ) # keys KEY_REASON = "reason" diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index aef60828334..025eb57c540 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -388,7 +388,7 @@ def send_sql_simplifier_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, - TelemetryField.SQL_SIMPLIFIER_ENABLED.value: True, + TelemetryField.SQL_SIMPLIFIER_ENABLED.value: sql_simplifier_enabled, }, } self.send(message) @@ -442,7 +442,7 @@ def send_large_query_breakdown_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: True, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: value, }, } self.send(message) @@ -516,3 +516,22 @@ def send_temp_table_cleanup_abnormal_exception_telemetry( }, } self.send(message) + + def send_large_query_breakdown_update_complexity_bounds( + self, session_id: int, lower_bound: int, upper_bound: int + ): + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_DATA.value: { + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: ( + lower_bound, + upper_bound, + ), + }, + }, + } + self.send(message) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 8ffd4081473..e66155e01ea 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -223,6 +223,16 @@ _PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION = ( "PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION" ) +_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND = ( + "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND" +) +_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND = ( + "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND" +) +# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT +# in Snowflake. This is the limit where we start seeing compilation errors. +DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000 +DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000 WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None @@ -577,6 +587,18 @@ def __init__( _PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION, False ) ) + # The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT + # in Snowflake. This is the limit where we start seeing compilation errors. + self._large_query_breakdown_complexity_bounds: Tuple[int, int] = ( + self._conn._get_client_side_session_parameter( + _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + ), + self._conn._get_client_side_session_parameter( + _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + ), + ) self._custom_package_usage_config: Dict = {} self._conf = self.RuntimeConfig(self, options or {}) self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None @@ -689,6 +711,10 @@ def auto_clean_up_temp_table_enabled(self) -> bool: def large_query_breakdown_enabled(self) -> bool: return self._large_query_breakdown_enabled + @property + def large_query_breakdown_complexity_bounds(self) -> Tuple[int, int]: + return self._large_query_breakdown_complexity_bounds + @property def custom_package_usage_config(self) -> Dict: """Get or set configuration parameters related to usage of custom Python packages in Snowflake. @@ -799,6 +825,24 @@ def large_query_breakdown_enabled(self, value: bool) -> None: "value for large_query_breakdown_enabled must be True or False!" ) + @large_query_breakdown_complexity_bounds.setter + def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: + """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" + + if len(value) != 2: + raise ValueError( + f"Expecting a tuple of two integers. Got a tuple of length {len(value)}" + ) + if value[0] >= value[1]: + raise ValueError( + f"Expecting a tuple of lower and upper bound with the lower bound less than the upper bound. Got (lower, upper) = ({value[0], value[1]})" + ) + self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds( + self._session_id, value[0], value[1] + ) + + self._large_query_breakdown_complexity_bounds = value + @custom_package_usage_config.setter @experimental_parameter(version="1.6.0") def custom_package_usage_config(self, config: Dict) -> None: diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index e42a504a976..bdd780ea69e 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -9,9 +9,13 @@ import pytest from snowflake.snowpark._internal.analyzer import analyzer -from snowflake.snowpark._internal.compiler import large_query_breakdown from snowflake.snowpark.functions import col, lit, sum_distinct, when_matched from snowflake.snowpark.row import Row +from snowflake.snowpark.session import ( + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + Session, +) from tests.utils import Utils pytestmark = [ @@ -22,9 +26,6 @@ ) ] -DEFAULT_LOWER_BOUND = large_query_breakdown.COMPLEXITY_SCORE_LOWER_BOUND -DEFAULT_UPPER_BOUND = large_query_breakdown.COMPLEXITY_SCORE_UPPER_BOUND - @pytest.fixture(autouse=True) def large_query_df(session): @@ -50,20 +51,24 @@ def setup(session): is_query_compilation_stage_enabled = session._query_compilation_stage_enabled session._query_compilation_stage_enabled = True session._large_query_breakdown_enabled = True + set_bounds(session, 300, 600) yield session._query_compilation_stage_enabled = is_query_compilation_stage_enabled session._cte_optimization_enabled = cte_optimization_enabled session._large_query_breakdown_enabled = large_query_breakdown_enabled - reset_bounds() + reset_bounds(session) -def set_bounds(lower_bound: int, upper_bound: int): - large_query_breakdown.COMPLEXITY_SCORE_LOWER_BOUND = lower_bound - large_query_breakdown.COMPLEXITY_SCORE_UPPER_BOUND = upper_bound +def set_bounds(session: Session, lower_bound: int, upper_bound: int): + session._large_query_breakdown_complexity_bounds = (lower_bound, upper_bound) -def reset_bounds(): - set_bounds(DEFAULT_LOWER_BOUND, DEFAULT_UPPER_BOUND) +def reset_bounds(session: Session): + set_bounds( + session, + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + ) def check_result_with_and_without_breakdown(session, df): @@ -82,8 +87,6 @@ def check_result_with_and_without_breakdown(session, df): def test_no_valid_nodes_found(session, large_query_df, caplog): """Test large query breakdown works with default bounds""" - set_bounds(300, 600) - base_df = session.sql("select 1 as A, 2 as B") df1 = base_df.with_column("A", col("A") + lit(1)) df2 = base_df.with_column("B", col("B") + lit(1)) @@ -104,7 +107,6 @@ def test_no_valid_nodes_found(session, large_query_df, caplog): def test_large_query_breakdown_with_cte_optimization(session): """Test large query breakdown works with cte optimized plan""" - set_bounds(300, 600) session._cte_optimization_enabled = True df0 = session.sql("select 2 as b, 32 as c") df1 = session.sql("select 1 as a, 2 as b").filter(col("a") == 1) @@ -131,7 +133,6 @@ def test_large_query_breakdown_with_cte_optimization(session): def test_save_as_table(session, large_query_df): - set_bounds(300, 600) table_name = Utils.random_table_name() with session.query_history() as history: large_query_df.write.save_as_table(table_name, mode="overwrite") @@ -146,7 +147,6 @@ def test_save_as_table(session, large_query_df): def test_update_delete_merge(session, large_query_df): - set_bounds(300, 600) session._large_query_breakdown_enabled = True table_name = Utils.random_table_name() df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"]) @@ -186,7 +186,6 @@ def test_update_delete_merge(session, large_query_df): def test_copy_into_location(session, large_query_df): - set_bounds(300, 600) remote_file_path = f"{session.get_session_stage()}/df.parquet" with session.query_history() as history: large_query_df.write.copy_into_location( @@ -204,7 +203,6 @@ def test_copy_into_location(session, large_query_df): def test_pivot_unpivot(session): - set_bounds(300, 600) session.sql( """create or replace temp table monthly_sales(A int, B int, month text) as select * from values @@ -243,7 +241,6 @@ def test_pivot_unpivot(session): def test_sort(session): - set_bounds(300, 600) base_df = session.sql("select 1 as A, 2 as B") df1 = base_df.with_column("A", col("A") + lit(1)) df2 = base_df.with_column("B", col("B") + lit(1)) @@ -276,7 +273,6 @@ def test_sort(session): def test_multiple_query_plan(session, large_query_df): - set_bounds(300, 600) original_threshold = analyzer.ARRAY_BIND_THRESHOLD try: analyzer.ARRAY_BIND_THRESHOLD = 2 @@ -314,7 +310,6 @@ def test_multiple_query_plan(session, large_query_df): def test_optimization_skipped_with_transaction(session, large_query_df, caplog): """Test large query breakdown is skipped when transaction is enabled""" - set_bounds(300, 600) session.sql("begin").collect() assert Utils.is_active_transaction(session) with caplog.at_level(logging.DEBUG): @@ -330,7 +325,6 @@ def test_optimization_skipped_with_transaction(session, large_query_df, caplog): def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog): """Test large query breakdown is skipped plan is a view or dynamic table""" - set_bounds(300, 600) source_table = Utils.random_table_name() table_name = Utils.random_table_name() view_name = Utils.random_view_name() @@ -360,7 +354,6 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog): def test_async_job_with_large_query_breakdown(session, large_query_df): """Test large query breakdown gives same result for async and non-async jobs""" - set_bounds(300, 600) job = large_query_df.collect(block=False) result = job.result() assert result == large_query_df.collect() @@ -376,8 +369,6 @@ def test_async_job_with_large_query_breakdown(session, large_query_df): def test_add_parent_plan_uuid_to_statement_params(session, large_query_df): - set_bounds(300, 600) - with patch.object( session._conn, "run_query", wraps=session._conn.run_query ) as patched_run_query: @@ -400,7 +391,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): """Test complexity bounds affect number of partitions. Also test that when partitions are added, drop table queries are added. """ - set_bounds(300, 600) + set_bounds(session, 300, 600) assert len(large_query_df.queries["queries"]) == 2 assert len(large_query_df.queries["post_actions"]) == 1 assert large_query_df.queries["queries"][0].startswith( @@ -410,7 +401,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): "DROP TABLE If EXISTS" ) - set_bounds(300, 412) + set_bounds(session, 300, 412) assert len(large_query_df.queries["queries"]) == 3 assert len(large_query_df.queries["post_actions"]) == 2 assert large_query_df.queries["queries"][0].startswith( @@ -426,11 +417,11 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): "DROP TABLE If EXISTS" ) - set_bounds(0, 300) + set_bounds(session, 0, 300) assert len(large_query_df.queries["queries"]) == 1 assert len(large_query_df.queries["post_actions"]) == 0 - reset_bounds() + reset_bounds(session) assert len(large_query_df.queries["queries"]) == 1 assert len(large_query_df.queries["post_actions"]) == 0 diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index df0afc1099b..21e77883338 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -5,6 +5,7 @@ import os from functools import partial +from unittest.mock import patch import pytest @@ -719,6 +720,31 @@ def test_eliminate_numeric_sql_value_cast_optimization_enabled_on_session( new_session.eliminate_numeric_sql_value_cast_enabled = None +def test_large_query_breakdown_complexity_bounds(session): + original_bounds = session.large_query_breakdown_complexity_bounds + try: + with pytest.raises(ValueError, match="Expecting a tuple of two integers"): + session.large_query_breakdown_complexity_bounds = (1, 2, 3) + + with pytest.raises( + ValueError, match="Expecting a tuple of lower and upper bound" + ): + session.large_query_breakdown_complexity_bounds = (3, 2) + + with patch.object( + session._conn._telemetry_client, + "send_large_query_breakdown_update_complexity_bounds", + ) as patch_send: + session.large_query_breakdown_complexity_bounds = (1, 2) + assert session.large_query_breakdown_complexity_bounds == (1, 2) + assert patch_send.call_count == 1 + assert patch_send.call_args[0][0] == session.session_id + assert patch_send.call_args[0][1] == 1 + assert patch_send.call_args[0][2] == 2 + finally: + session.large_query_breakdown_complexity_bounds = original_bounds + + @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_create_session_from_default_config_file(monkeypatch, db_parameters): import tomlkit From a3586c8ff41ab004d4662e39831c322d5843c8ac Mon Sep 17 00:00:00 2001 From: Naren Krishna Date: Fri, 13 Sep 2024 15:18:40 -0700 Subject: [PATCH 07/22] SNOW-1662105, SNOW-1662657: Support `by`, `left_by`, `right_by` for `pd.merge_asof` (#2284) SNOW-1662105, SNOW-1662657 This PR refactors `join_utils.py` to make `left_on` and `right_on` optional arguments, as they are not required for "cross" or "asof" joins. It also support `by`, `left_by`, `right_by` for `pd.merge_asof`. --------- Signed-off-by: Naren Krishna --- CHANGELOG.md | 1 + .../modin/supported/general_supported.rst | 3 +- .../modin/plugin/_internal/cut_utils.py | 2 - .../modin/plugin/_internal/indexing_utils.py | 2 - .../modin/plugin/_internal/join_utils.py | 123 +++++++++++++----- .../plugin/_internal/ordered_dataframe.py | 31 +++-- .../modin/plugin/_internal/resample_utils.py | 2 - .../compiler/snowflake_query_compiler.py | 96 +++++++++----- tests/integ/modin/test_merge_asof.py | 63 +++++---- 9 files changed, 212 insertions(+), 111 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d8efe82596f..0bd719dcb8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - Added support for `TimedeltaIndex.mean` method. - Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. +- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`. ## 1.22.1 (2024-09-11) diff --git a/docs/source/modin/supported/general_supported.rst b/docs/source/modin/supported/general_supported.rst index 797ef3bbd59..95d9610202b 100644 --- a/docs/source/modin/supported/general_supported.rst +++ b/docs/source/modin/supported/general_supported.rst @@ -38,8 +38,7 @@ Data manipulations +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``merge_asof`` | P | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``. | -| | | , ``left_index``, ``right_index``| | +| ``merge_asof`` | P | ``left_index``, ``right_index``, | ``N`` if param ``direction`` is ``nearest``. | | | | , ``suffixes``, ``tolerance`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge_ordered`` | N | | | diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py index 882dc79d2a8..4eaf98d9b29 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py @@ -189,8 +189,6 @@ def compute_bin_indices( values_frame, cuts_frame, how="asof", - left_on=[], - right_on=[], left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0], right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0], match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO diff --git a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py index c2c224e404c..6207bd2399a 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py @@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame( key, count_frame, "cross", - left_on=[], - right_on=[], inherit_join_index=InheritJoinIndex.FROM_LEFT, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 79f063b9ece..d07211dbcf5 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple): result_column_mapper: JoinOrAlignResultColumnMapper +def assert_snowpark_pandas_types_match( + left: InternalFrame, + right: InternalFrame, + left_join_identifiers: list[str], + right_join_identifiers: list[str], +) -> None: + """ + If Snowpark pandas types do not match for the given identifiers, then a ValueError will be raised. + + Args: + left: An internal frame to use on left side of join. + right: An internal frame to use on right side of join. + left_join_identifiers: List of snowflake identifiers to check types from 'left' frame. + right_join_identifiers: List of snowflake identifiers to check types from 'right' frame. + left_identifiers and right_identifiers must be lists of equal length. + + Returns: None + + Raises: ValueError + """ + left_types = [ + left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in left_join_identifiers + ] + right_types = [ + right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in right_join_identifiers + ] + for i, (lt, rt) in enumerate(zip(left_types, right_types)): + if lt != rt: + left_on_id = left_join_identifiers[i] + idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) + key = left.data_column_pandas_labels[idx] + lt = lt if lt is not None else left.get_snowflake_type(left_on_id) + rt = ( + rt + if rt is not None + else right.get_snowflake_type(right_join_identifiers[i]) + ) + raise ValueError( + f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. " + f"If you wish to proceed you should use pd.concat" + ) + + def join( left: InternalFrame, right: InternalFrame, how: JoinTypeLit, - left_on: list[str], - right_on: list[str], + left_on: Optional[list[str]] = None, + right_on: Optional[list[str]] = None, left_match_col: Optional[str] = None, right_match_col: Optional[str] = None, match_comparator: Optional[MatchComparator] = None, @@ -161,40 +206,48 @@ def join( include mapping for index + data columns, ordering columns and row position column if exists. """ - assert len(left_on) == len( - right_on - ), "left_on and right_on must be of same length or both be None" - if join_key_coalesce_config is not None: - assert len(join_key_coalesce_config) == len( - left_on - ), "join_key_coalesce_config must be of same length as left_on and right_on" assert how in get_args( JoinTypeLit ), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}" - def assert_snowpark_pandas_types_match() -> None: - """If Snowpark pandas types do not match, then a ValueError will be raised.""" - left_types = [ - left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in left_on - ] - right_types = [ - right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in right_on - ] - for i, (lt, rt) in enumerate(zip(left_types, right_types)): - if lt != rt: - left_on_id = left_on[i] - idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) - key = left.data_column_pandas_labels[idx] - lt = lt if lt is not None else left.get_snowflake_type(left_on_id) - rt = rt if rt is not None else right.get_snowflake_type(right_on[i]) - raise ValueError( - f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. " - f"If you wish to proceed you should use pd.concat" - ) + left_on = left_on or [] + right_on = right_on or [] + assert len(left_on) == len( + right_on + ), "left_on and right_on must be of same length or both be None" - assert_snowpark_pandas_types_match() + if how == "asof": + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + left_join_key = [left_match_col] + right_join_key = [right_match_col] + left_join_key.extend(left_on) + right_join_key.extend(right_on) + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key" + else: + left_join_key = left_on + right_join_key = right_on + assert ( + left_match_col is None + and right_match_col is None + and match_comparator is None + ), f"match condition should not be provided for {how} join" + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "join_key_coalesce_config must be of same length as left_on and right_on" + + assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key) # Re-project the active columns to make sure all active columns of the internal frame participate # in the join operation, and unnecessary columns are dropped from the projected columns. @@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None: match_comparator=match_comparator, how=how, ) - return _create_internal_frame_with_join_or_align_result( joined_ordered_dataframe, left, right, how, - left_on, - right_on, + left_join_key, + right_join_key, sort, join_key_coalesce_config, inherit_join_index, @@ -1402,6 +1454,9 @@ def _sort_on_join_keys(self) -> None: ) elif self._how == "right": ordering_column_identifiers = mapped_right_on + elif self._how == "asof": + # Order only by the left match_condition column + ordering_column_identifiers = [mapped_left_on[0]] else: # left join, inner join, left align, coalesce align ordering_column_identifiers = mapped_left_on diff --git a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py index f7ae87c2a5d..91537d98e30 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py @@ -1197,22 +1197,29 @@ def join( # get the new mapped right on identifier right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols] - # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' - on = None - for left_col, right_col in zip(left_on_cols, right_on_cols): - eq = Column(left_col).equal_null(Column(right_col)) - on = eq if on is None else on & eq - if how == "asof": - assert left_match_col, "left_match_col was not provided to ASOF Join" + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" left_match_col = Column(left_match_col) # Get the new mapped right match condition identifier - assert right_match_col, "right_match_col was not provided to ASOF Join" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" right_match_col = Column(right_identifiers_rename_map[right_match_col]) # ASOF Join requires the use of match_condition - assert match_comparator, "match_comparator was not provided to ASOF Join" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + + on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).__eq__(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right=right_snowpark_dataframe_ref.snowpark_dataframe, + on=on, how=how, match_condition=getattr(left_match_col, match_comparator.value)( right_match_col @@ -1224,6 +1231,12 @@ def join( right_snowpark_dataframe_ref.snowpark_dataframe, how=how ) else: + # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' + on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).equal_null(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right_snowpark_dataframe_ref.snowpark_dataframe, on, how ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py index de83e0429bf..ba8ceedec5e 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py @@ -649,8 +649,6 @@ def perform_asof_join_on_frame( left=preserving_frame, right=referenced_frame, how="asof", - left_on=[], - right_on=[], left_match_col=left_timecol_snowflake_quoted_identifier, right_match_col=right_timecol_snowflake_quoted_identifier, match_comparator=( diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 8ef3bdf9bee..48f91ab40dd 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -7381,28 +7381,34 @@ def merge_asof( SnowflakeQueryCompiler """ # TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation - if ( - by - or left_by - or right_by - or left_index - or right_index - or tolerance - or suffixes != ("_x", "_y") - ): + if left_index or right_index or tolerance or suffixes != ("_x", "_y"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method does not currently support parameters " - + "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - + "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ) if direction not in ("backward", "forward"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'" ) + if direction == "backward": + match_comparator = ( + MatchComparator.GREATER_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.GREATER_THAN + ) + else: + match_comparator = ( + MatchComparator.LESS_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.LESS_THAN + ) + left_frame = self._modin_frame right_frame = right._modin_frame - left_keys, right_keys = join_utils.get_join_keys( + # Get the left and right matching key and quoted identifier corresponding to the match_condition + # There will only be matching key/identifier for each table as there is only a single match condition + left_match_keys, right_match_keys = join_utils.get_join_keys( left=left_frame, right=right_frame, on=on, @@ -7411,42 +7417,62 @@ def merge_asof( left_index=left_index, right_index=right_index, ) - left_match_col = ( + left_match_identifier = ( left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - left_keys + left_match_keys )[0][0] ) - right_match_col = ( + right_match_identifier = ( right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - right_keys + right_match_keys )[0][0] ) - - if direction == "backward": - match_comparator = ( - MatchComparator.GREATER_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.GREATER_THAN + coalesce_config = join_utils.get_coalesce_config( + left_keys=left_match_keys, + right_keys=right_match_keys, + external_join_keys=[], + ) + + # Get the left and right matching keys and quoted identifiers corresponding to the 'on' condition + if by or (left_by and right_by): + left_on_keys, right_on_keys = join_utils.get_join_keys( + left=left_frame, + right=right_frame, + on=by, + left_on=left_by, + right_on=right_by, + ) + left_on_identifiers = [ + ids[0] + for ids in left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + left_on_keys + ) + ] + right_on_identifiers = [ + ids[0] + for ids in right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + right_on_keys + ) + ] + coalesce_config.extend( + join_utils.get_coalesce_config( + left_keys=left_on_keys, + right_keys=right_on_keys, + external_join_keys=[], + ) ) else: - match_comparator = ( - MatchComparator.LESS_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.LESS_THAN - ) - - coalesce_config = join_utils.get_coalesce_config( - left_keys=left_keys, right_keys=right_keys, external_join_keys=[] - ) + left_on_identifiers = [] + right_on_identifiers = [] joined_frame, _ = join_utils.join( left=left_frame, right=right_frame, + left_on=left_on_identifiers, + right_on=right_on_identifiers, how="asof", - left_on=[left_match_col], - right_on=[right_match_col], - left_match_col=left_match_col, - right_match_col=right_match_col, + left_match_col=left_match_identifier, + right_match_col=right_match_identifier, match_comparator=match_comparator, join_key_coalesce_config=coalesce_config, sort=True, diff --git a/tests/integ/modin/test_merge_asof.py b/tests/integ/modin/test_merge_asof.py index 681d339da90..51dda7889e7 100644 --- a/tests/integ/modin/test_merge_asof.py +++ b/tests/integ/modin/test_merge_asof.py @@ -105,6 +105,7 @@ def left_right_timestamp_data(): pd.Timestamp("2016-05-25 13:30:00.072"), pd.Timestamp("2016-05-25 13:30:00.075"), ], + "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"], "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], } @@ -118,6 +119,7 @@ def left_right_timestamp_data(): pd.Timestamp("2016-05-25 13:30:00.048"), pd.Timestamp("2016-05-25 13:30:00.048"), ], + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], "price": [51.95, 51.95, 720.77, 720.92, 98.0], "quantity": [75, 155, 100, 100, 100], } @@ -229,14 +231,39 @@ def test_merge_asof_left_right_on( assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) +@pytest.mark.parametrize("by", ["ticker", ["ticker"]]) @sql_count_checker(query_count=1, join_count=1) -def test_merge_asof_timestamps(left_right_timestamp_data): +def test_merge_asof_by(left_right_timestamp_data, by): left_native_df, right_native_df = left_right_timestamp_data left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( right_native_df ) - native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time") - snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time") + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", by=by + ) + snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by=by) + assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) + + +@pytest.mark.parametrize( + "left_by, right_by", + [ + ("ticker", "ticker"), + (["ticker", "bid"], ["ticker", "price"]), + ], +) +@sql_count_checker(query_count=1, join_count=1) +def test_merge_asof_left_right_by(left_right_timestamp_data, left_by, right_by): + left_native_df, right_native_df = left_right_timestamp_data + left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( + right_native_df + ) + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", left_by=left_by, right_by=right_by + ) + snow_output = pd.merge_asof( + left_snow_df, right_snow_df, on="time", left_by=left_by, right_by=right_by + ) assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) @@ -248,8 +275,10 @@ def test_merge_asof_date(left_right_timestamp_data): left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( right_native_df ) - native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time") - snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time") + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", by="ticker" + ) + snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by="ticker") assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) @@ -360,9 +389,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): with pytest.raises( NotImplementedError, match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'" ), ): pd.merge_asof( @@ -372,19 +399,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" - ), - ): - pd.merge_asof( - left_snow_df, right_snow_df, on="time", left_by="price", right_by="quantity" - ) - with pytest.raises( - NotImplementedError, - match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof(left_snow_df, right_snow_df, left_index=True, right_index=True) @@ -392,8 +407,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof( @@ -406,8 +420,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof( From 7a5e6bbdfd9e44448116ffacce4595e88ad41c0e Mon Sep 17 00:00:00 2001 From: Naresh Kumar <113932371+sfc-gh-nkumar@users.noreply.github.com> Date: Fri, 13 Sep 2024 17:16:32 -0700 Subject: [PATCH 08/22] SNOW-1320674: Add tests to verify loc raises KeyError for invalid labels (#2293) SNOW-1320674: This bug is already fixed in main brach. Adding a explicit test case before closing it. --- tests/integ/modin/frame/test_loc.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/integ/modin/frame/test_loc.py b/tests/integ/modin/frame/test_loc.py index be51b8c9ae6..105bf475f3a 100644 --- a/tests/integ/modin/frame/test_loc.py +++ b/tests/integ/modin/frame/test_loc.py @@ -4072,3 +4072,22 @@ def test_df_loc_get_with_timedelta_and_none_key(): # Compare with an empty DataFrame, since native pandas raises a KeyError. expected_df = native_pd.DataFrame() assert_frame_equal(snow_df.loc[None], expected_df, check_column_type=False) + + +@sql_count_checker(query_count=0) +def test_df_loc_invalid_key(): + # Bug fix: SNOW-1320674 + native_df = native_pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + snow_df = pd.DataFrame(native_df) + + def op(df): + df["C"] = df["A"] / df["D"] + + eval_snowpark_pandas_result( + snow_df, + native_df, + op, + expect_exception=True, + expect_exception_type=KeyError, + expect_exception_match="D", + ) From 64ced96ea9519d2bb5c67d05cd6a0986848b7ace Mon Sep 17 00:00:00 2001 From: Jonathan Shi <149419494+sfc-gh-joshi@users.noreply.github.com> Date: Fri, 13 Sep 2024 17:54:49 -0700 Subject: [PATCH 09/22] SNOW-1063346: Remove modin/pandas/dataframe.py (#2223) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1063346 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. This PR removes dataframe.py (the Snowpark pandas one, not the Snowpark Python one), following #2167 and #2205. Once more, preserved overrides are given a reason in code comments. The following implemented methods have been added to `dataframe_overrides.py`: - __init__ - __dataframe__ - __and__, __rand__, __or__, __ror__ - apply, applymap - columns - corr - dropna - fillna - groupby - info - insert - isin - join - mask - melt - merge - replace - rename - pivot_table - pow - rpow - select_dtypes - set_axis - set_index - shape - squeeze - sum - stack - transpose - unstack - value_counts - where - iterrows - itertuples - __repr__ - _repr_html_ - _to_datetime - _to_pandas - __setitem__ --- .../snowpark/modin/pandas/__init__.py | 14 +- .../modin/pandas/api/extensions/__init__.py | 7 +- .../modin/pandas/api/extensions/extensions.py | 43 - .../snowpark/modin/pandas/dataframe.py | 3511 ----------------- .../snowpark/modin/pandas/general.py | 20 +- .../snowpark/modin/pandas/indexing.py | 2 +- src/snowflake/snowpark/modin/pandas/io.py | 4 +- .../modin/pandas/snow_partition_iterator.py | 3 +- src/snowflake/snowpark/modin/pandas/utils.py | 12 +- .../snowpark/modin/plugin/__init__.py | 15 +- .../modin/plugin/_internal/telemetry.py | 4 +- .../snowpark/modin/plugin/_internal/utils.py | 2 +- .../compiler/snowflake_query_compiler.py | 2 +- .../snowpark/modin/plugin/docstrings/base.py | 1 + .../modin/plugin/docstrings/dataframe.py | 2 +- .../modin/plugin/extensions/base_overrides.py | 26 +- .../plugin/extensions/dataframe_overrides.py | 2223 ++++++++++- .../snowpark/modin/plugin/extensions/index.py | 2 +- .../modin/plugin/extensions/pd_extensions.py | 2 +- .../modin/plugin/extensions/pd_overrides.py | 2 +- .../plugin/extensions/series_extensions.py | 1 - .../plugin/extensions/series_overrides.py | 159 +- .../plugin/extensions/timedelta_index.py | 2 +- .../modin/plugin/utils/frontend_constants.py | 14 + src/snowflake/snowpark/modin/utils.py | 2 +- tests/integ/modin/frame/test_info.py | 4 +- tests/integ/modin/test_classes.py | 8 +- tests/integ/modin/test_telemetry.py | 29 +- tests/unit/modin/modin/test_envvars.py | 1 + 29 files changed, 2339 insertions(+), 3778 deletions(-) delete mode 100644 src/snowflake/snowpark/modin/pandas/dataframe.py diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index 6960d0eb629..8f9834630b7 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -88,13 +88,13 @@ # TODO: SNOW-851745 make sure add all Snowpark pandas API general functions from modin.pandas import plotting # type: ignore[import] +from modin.pandas.dataframe import DataFrame from modin.pandas.series import Series from snowflake.snowpark.modin.pandas.api.extensions import ( register_dataframe_accessor, register_series_accessor, ) -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.general import ( bdate_range, concat, @@ -185,10 +185,8 @@ modin.pandas.base._ATTRS_NO_LOOKUP.update(_ATTRS_NO_LOOKUP) -# For any method defined on Series/DF, add telemetry to it if it: -# 1. Is defined directly on an upstream class -# 2. The method name does not start with an _, or is in TELEMETRY_PRIVATE_METHODS - +# For any method defined on Series/DF, add telemetry to it if the method name does not start with an +# _, or the method is in TELEMETRY_PRIVATE_METHODS. This includes methods defined as an extension/override. for attr_name in dir(Series): # Since Series is defined in upstream Modin, all of its members were either defined upstream # or overridden by extension. @@ -197,11 +195,9 @@ try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name)) ) - -# TODO: SNOW-1063346 -# Since we still use the vendored version of DataFrame and the overrides for the top-level -# namespace haven't been performed yet, we need to set properties on the vendored version for attr_name in dir(DataFrame): + # Since DataFrame is defined in upstream Modin, all of its members were either defined upstream + # or overridden by extension. if not attr_name.startswith("_") or attr_name in TELEMETRY_PRIVATE_METHODS: register_dataframe_accessor(attr_name)( try_add_telemetry_to_attribute(attr_name, getattr(DataFrame, attr_name)) diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py index 6a34f50e42a..47d44835fe4 100644 --- a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py @@ -19,9 +19,12 @@ # existing code originally distributed by the Modin project, under the Apache License, # Version 2.0. -from modin.pandas.api.extensions import register_series_accessor +from modin.pandas.api.extensions import ( + register_dataframe_accessor, + register_series_accessor, +) -from .extensions import register_dataframe_accessor, register_pd_accessor +from .extensions import register_pd_accessor __all__ = [ "register_dataframe_accessor", diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py index 45896292e74..05424c92072 100644 --- a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py @@ -86,49 +86,6 @@ def decorator(new_attr: Any): return decorator -def register_dataframe_accessor(name: str): - """ - Registers a dataframe attribute with the name provided. - This is a decorator that assigns a new attribute to DataFrame. It can be used - with the following syntax: - ``` - @register_dataframe_accessor("new_method") - def my_new_dataframe_method(*args, **kwargs): - # logic goes here - return - ``` - The new attribute can then be accessed with the name provided: - ``` - df.new_method(*my_args, **my_kwargs) - ``` - - If you want a property accessor, you must annotate with @property - after the call to this function: - ``` - @register_dataframe_accessor("new_prop") - @property - def my_new_dataframe_property(*args, **kwargs): - return _prop - ``` - - Parameters - ---------- - name : str - The name of the attribute to assign to DataFrame. - Returns - ------- - decorator - Returns the decorator function. - """ - import snowflake.snowpark.modin.pandas as pd - - return _set_attribute_on_obj( - name, - pd.dataframe._DATAFRAME_EXTENSIONS_, - pd.dataframe.DataFrame, - ) - - def register_pd_accessor(name: str): """ Registers a pd namespace attribute with the name provided. diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py deleted file mode 100644 index 83893e83e9c..00000000000 --- a/src/snowflake/snowpark/modin/pandas/dataframe.py +++ /dev/null @@ -1,3511 +0,0 @@ -# -# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. -# - -# Licensed to Modin Development Team under one or more contributor license agreements. -# See the NOTICE file distributed with this work for additional information regarding -# copyright ownership. The Modin Development Team licenses this file to you under the -# Apache License, Version 2.0 (the "License"); you may not use this file except in -# compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under -# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific language -# governing permissions and limitations under the License. - -# Code in this file may constitute partial or total reimplementation, or modification of -# existing code originally distributed by the Modin project, under the Apache License, -# Version 2.0. - -"""Module houses ``DataFrame`` class, that is distributed version of ``pandas.DataFrame``.""" - -from __future__ import annotations - -import collections -import datetime -import functools -import itertools -import re -import sys -import warnings -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from logging import getLogger -from typing import IO, Any, Callable, Literal - -import numpy as np -import pandas -from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor -from modin.pandas.base import BasePandasDataset - -# from . import _update_engine -from modin.pandas.iterator import PartitionIterator -from modin.pandas.series import Series -from pandas._libs.lib import NoDefault, no_default -from pandas._typing import ( - AggFuncType, - AnyArrayLike, - Axes, - Axis, - CompressionOptions, - FilePath, - FillnaOptions, - IgnoreRaise, - IndexLabel, - Level, - PythonFuncType, - Renamer, - Scalar, - StorageOptions, - Suffixes, - WriteBuffer, -) -from pandas.core.common import apply_if_callable, is_bool_indexer -from pandas.core.dtypes.common import ( - infer_dtype_from_object, - is_bool_dtype, - is_dict_like, - is_list_like, - is_numeric_dtype, -) -from pandas.core.dtypes.inference import is_hashable, is_integer -from pandas.core.indexes.frozen import FrozenList -from pandas.io.formats.printing import pprint_thing -from pandas.util._validators import validate_bool_kwarg - -from snowflake.snowpark.modin import pandas as pd -from snowflake.snowpark.modin.pandas.groupby import ( - DataFrameGroupBy, - validate_groupby_args, -) -from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( - SnowparkPandasRowPartitionIterator, -) -from snowflake.snowpark.modin.pandas.utils import ( - create_empty_native_pandas_frame, - from_non_pandas, - from_pandas, - is_scalar, - raise_if_native_pandas_objects, - replace_external_data_keys_with_empty_pandas_series, - replace_external_data_keys_with_query_compiler, -) -from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated -from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike -from snowflake.snowpark.modin.plugin.utils.error_message import ( - ErrorMessage, - dataframe_not_implemented, -) -from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP -from snowflake.snowpark.modin.plugin.utils.warning_message import ( - SET_DATAFRAME_ATTRIBUTE_WARNING, - WarningMessage, -) -from snowflake.snowpark.modin.utils import _inherit_docstrings, hashable, to_pandas -from snowflake.snowpark.udf import UserDefinedFunction - -logger = getLogger(__name__) - -DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = ( - "Currently do not support Series or list-like keys with range-like values" -) - -DF_SETITEM_SLICE_AS_SCALAR_VALUE = ( - "Currently do not support assigning a slice value as if it's a scalar value" -) - -DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = ( - "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark " - "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which " - "can work on the entire DataFrame in one shot." -) - -# Dictionary of extensions assigned to this class -_DATAFRAME_EXTENSIONS_ = {} - - -@_inherit_docstrings( - pandas.DataFrame, - excluded=[ - pandas.DataFrame.flags, - pandas.DataFrame.cov, - pandas.DataFrame.merge, - pandas.DataFrame.reindex, - pandas.DataFrame.to_parquet, - pandas.DataFrame.fillna, - ], - apilink="pandas.DataFrame", -) -class DataFrame(BasePandasDataset): - _pandas_class = pandas.DataFrame - - def __init__( - self, - data=None, - index=None, - columns=None, - dtype=None, - copy=None, - query_compiler=None, - ) -> None: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # Siblings are other dataframes that share the same query compiler. We - # use this list to update inplace when there is a shallow copy. - from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native - - self._siblings = [] - - # Engine.subscribe(_update_engine) - if isinstance(data, (DataFrame, Series)): - self._query_compiler = data._query_compiler.copy() - if index is not None and any(i not in data.index for i in index): - ErrorMessage.not_implemented( - "Passing non-existant columns or index values to constructor not" - + " yet implemented." - ) # pragma: no cover - if isinstance(data, Series): - # We set the column name if it is not in the provided Series - if data.name is None: - self.columns = [0] if columns is None else columns - # If the columns provided are not in the named Series, pandas clears - # the DataFrame and sets columns to the columns provided. - elif columns is not None and data.name not in columns: - self._query_compiler = from_pandas( - self.__constructor__(columns=columns) - )._query_compiler - if index is not None: - self._query_compiler = data.loc[index]._query_compiler - elif columns is None and index is None: - data._add_sibling(self) - else: - if columns is not None and any(i not in data.columns for i in columns): - ErrorMessage.not_implemented( - "Passing non-existant columns or index values to constructor not" - + " yet implemented." - ) # pragma: no cover - if index is None: - index = slice(None) - if columns is None: - columns = slice(None) - self._query_compiler = data.loc[index, columns]._query_compiler - - # Check type of data and use appropriate constructor - elif query_compiler is None: - distributed_frame = from_non_pandas(data, index, columns, dtype) - if distributed_frame is not None: - self._query_compiler = distributed_frame._query_compiler - return - - if isinstance(data, pandas.Index): - pass - elif is_list_like(data) and not is_dict_like(data): - old_dtype = getattr(data, "dtype", None) - values = [ - obj._to_pandas() if isinstance(obj, Series) else obj for obj in data - ] - if isinstance(data, np.ndarray): - data = np.array(values, dtype=old_dtype) - else: - try: - data = type(data)(values, dtype=old_dtype) - except TypeError: - data = values - elif is_dict_like(data) and not isinstance( - data, (pandas.Series, Series, pandas.DataFrame, DataFrame) - ): - if columns is not None: - data = {key: value for key, value in data.items() if key in columns} - - if len(data) and all(isinstance(v, Series) for v in data.values()): - from .general import concat - - new_qc = concat( - data.values(), axis=1, keys=data.keys() - )._query_compiler - - if dtype is not None: - new_qc = new_qc.astype({col: dtype for col in new_qc.columns}) - if index is not None: - new_qc = new_qc.reindex( - axis=0, labels=try_convert_index_to_native(index) - ) - if columns is not None: - new_qc = new_qc.reindex( - axis=1, labels=try_convert_index_to_native(columns) - ) - - self._query_compiler = new_qc - return - - data = { - k: v._to_pandas() if isinstance(v, Series) else v - for k, v in data.items() - } - pandas_df = pandas.DataFrame( - data=try_convert_index_to_native(data), - index=try_convert_index_to_native(index), - columns=try_convert_index_to_native(columns), - dtype=dtype, - copy=copy, - ) - self._query_compiler = from_pandas(pandas_df)._query_compiler - else: - self._query_compiler = query_compiler - - def __repr__(self): - """ - Return a string representation for a particular ``DataFrame``. - - Returns - ------- - str - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - num_rows = pandas.get_option("display.max_rows") or len(self) - # see _repr_html_ for comment, allow here also all column behavior - num_cols = pandas.get_option("display.max_columns") or len(self.columns) - - ( - row_count, - col_count, - repr_df, - ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x") - result = repr(repr_df) - - # if truncated, add shape information - if is_repr_truncated(row_count, col_count, num_rows, num_cols): - # The split here is so that we don't repr pandas row lengths. - return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format( - row_count, col_count - ) - else: - return result - - def _repr_html_(self): # pragma: no cover - """ - Return a html representation for a particular ``DataFrame``. - - Returns - ------- - str - - Notes - ----- - Supports pandas `display.max_rows` and `display.max_columns` options. - """ - num_rows = pandas.get_option("display.max_rows") or 60 - # Modin uses here 20 as default, but this does not coincide well with pandas option. Therefore allow - # here value=0 which means display all columns. - num_cols = pandas.get_option("display.max_columns") - - ( - row_count, - col_count, - repr_df, - ) = self._query_compiler.build_repr_df(num_rows, num_cols) - result = repr_df._repr_html_() - - if is_repr_truncated(row_count, col_count, num_rows, num_cols): - # We split so that we insert our correct dataframe dimensions. - return ( - result.split("

")[0] - + f"

{row_count} rows × {col_count} columns

\n" - ) - else: - return result - - def _get_columns(self) -> pandas.Index: - """ - Get the columns for this Snowpark pandas ``DataFrame``. - - Returns - ------- - Index - The all columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.columns - - def _set_columns(self, new_columns: Axes) -> None: - """ - Set the columns for this Snowpark pandas ``DataFrame``. - - Parameters - ---------- - new_columns : - The new columns to set. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._update_inplace( - new_query_compiler=self._query_compiler.set_columns(new_columns) - ) - - columns = property(_get_columns, _set_columns) - - @property - def ndim(self) -> int: - return 2 - - def drop_duplicates( - self, subset=None, keep="first", inplace=False, ignore_index=False - ): # noqa: PR01, RT01, D200 - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - """ - Return ``DataFrame`` with duplicate rows removed. - """ - return super().drop_duplicates( - subset=subset, keep=keep, inplace=inplace, ignore_index=ignore_index - ) - - def dropna( - self, - *, - axis: Axis = 0, - how: str | NoDefault = no_default, - thresh: int | NoDefault = no_default, - subset: IndexLabel = None, - inplace: bool = False, - ): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super()._dropna( - axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace - ) - - @property - def dtypes(self): # noqa: RT01, D200 - """ - Return the dtypes in the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.dtypes - - def duplicated( - self, subset: Hashable | Sequence[Hashable] = None, keep: DropKeep = "first" - ): - """ - Return boolean ``Series`` denoting duplicate rows. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - df = self[subset] if subset is not None else self - new_qc = df._query_compiler.duplicated(keep=keep) - duplicates = self._reduce_dimension(new_qc) - # remove Series name which was assigned automatically by .apply in QC - # this is pandas behavior, i.e., if duplicated result is a series, no name is returned - duplicates.name = None - return duplicates - - @property - def empty(self) -> bool: - """ - Indicate whether ``DataFrame`` is empty. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return len(self.columns) == 0 or len(self) == 0 - - @property - def axes(self): - """ - Return a list representing the axes of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return [self.index, self.columns] - - @property - def shape(self) -> tuple[int, int]: - """ - Return a tuple representing the dimensionality of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return len(self), len(self.columns) - - def add_prefix(self, prefix): - """ - Prefix labels with string `prefix`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas converts non-string prefix values into str and adds it to the column labels. - return self.__constructor__( - query_compiler=self._query_compiler.add_substring( - str(prefix), substring_type="prefix", axis=1 - ) - ) - - def add_suffix(self, suffix): - """ - Suffix labels with string `suffix`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas converts non-string suffix values into str and appends it to the column labels. - return self.__constructor__( - query_compiler=self._query_compiler.add_substring( - str(suffix), substring_type="suffix", axis=1 - ) - ) - - @dataframe_not_implemented() - def map( - self, func, na_action: str | None = None, **kwargs - ) -> DataFrame: # pragma: no cover - if not callable(func): - raise ValueError(f"'{type(func)}' object is not callable") - return self.__constructor__( - query_compiler=self._query_compiler.map(func, na_action=na_action, **kwargs) - ) - - def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not callable(func): - raise TypeError(f"{func} is not callable") - return self.__constructor__( - query_compiler=self._query_compiler.applymap( - func, na_action=na_action, **kwargs - ) - ) - - def aggregate( - self, func: AggFuncType = None, axis: Axis = 0, *args: Any, **kwargs: Any - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().aggregate(func, axis, *args, **kwargs) - - agg = aggregate - - def apply( - self, - func: AggFuncType | UserDefinedFunction, - axis: Axis = 0, - raw: bool = False, - result_type: Literal["expand", "reduce", "broadcast"] | None = None, - args=(), - **kwargs, - ): - """ - Apply a function along an axis of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) - query_compiler = self._query_compiler.apply( - func, - axis, - raw=raw, - result_type=result_type, - args=args, - **kwargs, - ) - if not isinstance(query_compiler, type(self._query_compiler)): - # A scalar was returned - return query_compiler - - # If True, it is an unamed series. - # Theoretically, if df.apply returns a Series, it will only be an unnamed series - # because the function is supposed to be series -> scalar. - if query_compiler._modin_frame.is_unnamed_series(): - return Series(query_compiler=query_compiler) - else: - return self.__constructor__(query_compiler=query_compiler) - - def groupby( - self, - by=None, - axis: Axis | NoDefault = no_default, - level: IndexLabel | None = None, - as_index: bool = True, - sort: bool = True, - group_keys: bool = True, - observed: bool | NoDefault = no_default, - dropna: bool = True, - ): - """ - Group ``DataFrame`` using a mapper or by a ``Series`` of columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if axis is not no_default: - axis = self._get_axis_number(axis) - if axis == 1: - warnings.warn( - "DataFrame.groupby with axis=1 is deprecated. Do " - + "`frame.T.groupby(...)` without axis instead.", - FutureWarning, - stacklevel=1, - ) - else: - warnings.warn( - "The 'axis' keyword in DataFrame.groupby is deprecated and " - + "will be removed in a future version.", - FutureWarning, - stacklevel=1, - ) - else: - axis = 0 - - validate_groupby_args(by, level, observed) - - axis = self._get_axis_number(axis) - - if axis != 0 and as_index is False: - raise ValueError("as_index=False only valid for axis=0") - - idx_name = None - - if ( - not isinstance(by, Series) - and is_list_like(by) - and len(by) == 1 - # if by is a list-like of (None,), we have to keep it as a list because - # None may be referencing a column or index level whose label is - # `None`, and by=None wold mean that there is no `by` param. - and by[0] is not None - ): - by = by[0] - - if hashable(by) and ( - not callable(by) and not isinstance(by, (pandas.Grouper, FrozenList)) - ): - idx_name = by - elif isinstance(by, Series): - idx_name = by.name - if by._parent is self: - # if the SnowSeries comes from the current dataframe, - # convert it to labels directly for easy processing - by = by.name - elif is_list_like(by): - if axis == 0 and all( - ( - (hashable(o) and (o in self)) - or isinstance(o, Series) - or (is_list_like(o) and len(o) == len(self.shape[axis])) - ) - for o in by - ): - # plit 'by's into those that belongs to the self (internal_by) - # and those that doesn't (external_by). For SnowSeries that belongs - # to current DataFrame, we convert it to labels for easy process. - internal_by, external_by = [], [] - - for current_by in by: - if hashable(current_by): - internal_by.append(current_by) - elif isinstance(current_by, Series): - if current_by._parent is self: - internal_by.append(current_by.name) - else: - external_by.append(current_by) # pragma: no cover - else: - external_by.append(current_by) - - by = internal_by + external_by - - return DataFrameGroupBy( - self, - by, - axis, - level, - as_index, - sort, - group_keys, - idx_name, - observed=observed, - dropna=dropna, - ) - - def keys(self): # noqa: RT01, D200 - """ - Get columns of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.columns - - def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200 - """ - Transpose index and columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if copy: - WarningMessage.ignored_argument( - operation="transpose", - argument="copy", - message="Transpose ignore copy argument in Snowpark pandas API", - ) - - if args: - WarningMessage.ignored_argument( - operation="transpose", - argument="args", - message="Transpose ignores args in Snowpark pandas API", - ) - - return self.__constructor__(query_compiler=self._query_compiler.transpose()) - - T = property(transpose) - - def add( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get addition of ``DataFrame`` and `other`, element-wise (binary operator `add`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "add", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def assign(self, **kwargs): # noqa: PR01, RT01, D200 - """ - Assign new columns to a ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - df = self.copy() - for k, v in kwargs.items(): - if callable(v): - df[k] = v(df) - else: - df[k] = v - return df - - @dataframe_not_implemented() - def boxplot( - self, - column=None, - by=None, - ax=None, - fontsize=None, - rot=0, - grid=True, - figsize=None, - layout=None, - return_type=None, - backend=None, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Make a box plot from ``DataFrame`` columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return to_pandas(self).boxplot( - column=column, - by=by, - ax=ax, - fontsize=fontsize, - rot=rot, - grid=grid, - figsize=figsize, - layout=layout, - return_type=return_type, - backend=backend, - **kwargs, - ) - - @dataframe_not_implemented() - def combine( - self, other, func, fill_value=None, overwrite=True - ): # noqa: PR01, RT01, D200 - """ - Perform column-wise combine with another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().combine(other, func, fill_value=fill_value, overwrite=overwrite) - - def compare( - self, - other, - align_axis=1, - keep_shape: bool = False, - keep_equal: bool = False, - result_names=("self", "other"), - ) -> DataFrame: # noqa: PR01, RT01, D200 - """ - Compare to another ``DataFrame`` and show the differences. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not isinstance(other, DataFrame): - raise TypeError(f"Cannot compare DataFrame to {type(other)}") - other = self._validate_other(other, 0, compare_index=True) - return self.__constructor__( - query_compiler=self._query_compiler.compare( - other, - align_axis=align_axis, - keep_shape=keep_shape, - keep_equal=keep_equal, - result_names=result_names, - ) - ) - - def corr( - self, - method: str | Callable = "pearson", - min_periods: int | None = None, - numeric_only: bool = False, - ): # noqa: PR01, RT01, D200 - """ - Compute pairwise correlation of columns, excluding NA/null values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - corr_df = self - if numeric_only: - corr_df = self.drop( - columns=[ - i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) - ] - ) - return self.__constructor__( - query_compiler=corr_df._query_compiler.corr( - method=method, - min_periods=min_periods, - ) - ) - - @dataframe_not_implemented() - def corrwith( - self, other, axis=0, drop=False, method="pearson", numeric_only=False - ): # noqa: PR01, RT01, D200 - """ - Compute pairwise correlation. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, DataFrame): - other = other._query_compiler.to_pandas() - return self._default_to_pandas( - pandas.DataFrame.corrwith, - other, - axis=axis, - drop=drop, - method=method, - numeric_only=numeric_only, - ) - - @dataframe_not_implemented() - def cov( - self, - min_periods: int | None = None, - ddof: int | None = 1, - numeric_only: bool = False, - ): - """ - Compute pairwise covariance of columns, excluding NA/null values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.cov( - min_periods=min_periods, - ddof=ddof, - numeric_only=numeric_only, - ) - ) - - @dataframe_not_implemented() - def dot(self, other): # noqa: PR01, RT01, D200 - """ - Compute the matrix multiplication between the ``DataFrame`` and `other`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - if isinstance(other, BasePandasDataset): - common = self.columns.union(other.index) - if len(common) > len(self.columns) or len(common) > len( - other - ): # pragma: no cover - raise ValueError("Matrices are not aligned") - - if isinstance(other, DataFrame): - return self.__constructor__( - query_compiler=self._query_compiler.dot( - other.reindex(index=common), squeeze_self=False - ) - ) - else: - return self._reduce_dimension( - query_compiler=self._query_compiler.dot( - other.reindex(index=common), squeeze_self=False - ) - ) - - other = np.asarray(other) - if self.shape[1] != other.shape[0]: - raise ValueError( - f"Dot product shape mismatch, {self.shape} vs {other.shape}" - ) - - if len(other.shape) > 1: - return self.__constructor__( - query_compiler=self._query_compiler.dot(other, squeeze_self=False) - ) - - return self._reduce_dimension( - query_compiler=self._query_compiler.dot(other, squeeze_self=False) - ) - - def eq(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Perform equality comparison of ``DataFrame`` and `other` (binary operator `eq`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("eq", other, axis=axis, level=level) - - def equals(self, other) -> bool: # noqa: PR01, RT01, D200 - """ - Test whether two objects contain the same elements. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, pandas.DataFrame): - # Copy into a Modin DataFrame to simplify logic below - other = self.__constructor__(other) - - if ( - type(self) is not type(other) - or not self.index.equals(other.index) - or not self.columns.equals(other.columns) - ): - return False - - result = self.__constructor__( - query_compiler=self._query_compiler.equals(other._query_compiler) - ) - return result.all(axis=None) - - def _update_var_dicts_in_kwargs(self, expr, kwargs): - """ - Copy variables with "@" prefix in `local_dict` and `global_dict` keys of kwargs. - - Parameters - ---------- - expr : str - The expression string to search variables with "@" prefix. - kwargs : dict - See the documentation for eval() for complete details on the keyword arguments accepted by query(). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if "@" not in expr: - return - frame = sys._getframe() - try: - f_locals = frame.f_back.f_back.f_back.f_back.f_locals - f_globals = frame.f_back.f_back.f_back.f_back.f_globals - finally: - del frame - local_names = set(re.findall(r"@([\w]+)", expr)) - local_dict = {} - global_dict = {} - - for name in local_names: - for dct_out, dct_in in ((local_dict, f_locals), (global_dict, f_globals)): - try: - dct_out[name] = dct_in[name] - except KeyError: - pass - - if local_dict: - local_dict.update(kwargs.get("local_dict") or {}) - kwargs["local_dict"] = local_dict - if global_dict: - global_dict.update(kwargs.get("global_dict") or {}) - kwargs["global_dict"] = global_dict - - @dataframe_not_implemented() - def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 - """ - Evaluate a string describing operations on ``DataFrame`` columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._validate_eval_query(expr, **kwargs) - inplace = validate_bool_kwarg(inplace, "inplace") - self._update_var_dicts_in_kwargs(expr, kwargs) - new_query_compiler = self._query_compiler.eval(expr, **kwargs) - return_type = type( - pandas.DataFrame(columns=self.columns) - .astype(self.dtypes) - .eval(expr, **kwargs) - ).__name__ - if return_type == type(self).__name__: - return self._create_or_update_from_compiler(new_query_compiler, inplace) - else: - if inplace: - raise ValueError("Cannot operate inplace if there is no assignment") - return getattr(sys.modules[self.__module__], return_type)( - query_compiler=new_query_compiler - ) - - def fillna( - self, - value: Hashable | Mapping | Series | DataFrame = None, - *, - method: FillnaOptions | None = None, - axis: Axis | None = None, - inplace: bool = False, - limit: int | None = None, - downcast: dict | None = None, - ) -> DataFrame | None: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().fillna( - self_is_series=False, - value=value, - method=method, - axis=axis, - inplace=inplace, - limit=limit, - downcast=downcast, - ) - - def floordiv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `floordiv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "floordiv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - @classmethod - @dataframe_not_implemented() - def from_dict( - cls, data, orient="columns", dtype=None, columns=None - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Construct ``DataFrame`` from dict of array-like or dicts. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return from_pandas( - pandas.DataFrame.from_dict( - data, orient=orient, dtype=dtype, columns=columns - ) - ) - - @classmethod - @dataframe_not_implemented() - def from_records( - cls, - data, - index=None, - exclude=None, - columns=None, - coerce_float=False, - nrows=None, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Convert structured or record ndarray to ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return from_pandas( - pandas.DataFrame.from_records( - data, - index=index, - exclude=exclude, - columns=columns, - coerce_float=coerce_float, - nrows=nrows, - ) - ) - - def ge(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get greater than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("ge", other, axis=axis, level=level) - - def gt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get greater than comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("gt", other, axis=axis, level=level) - - @dataframe_not_implemented() - def hist( - self, - column=None, - by=None, - grid=True, - xlabelsize=None, - xrot=None, - ylabelsize=None, - yrot=None, - ax=None, - sharex=False, - sharey=False, - figsize=None, - layout=None, - bins=10, - **kwds, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Make a histogram of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.hist, - column=column, - by=by, - grid=grid, - xlabelsize=xlabelsize, - xrot=xrot, - ylabelsize=ylabelsize, - yrot=yrot, - ax=ax, - sharex=sharex, - sharey=sharey, - figsize=figsize, - layout=layout, - bins=bins, - **kwds, - ) - - def info( - self, - verbose: bool | None = None, - buf: IO[str] | None = None, - max_cols: int | None = None, - memory_usage: bool | str | None = None, - show_counts: bool | None = None, - null_counts: bool | None = None, - ): # noqa: PR01, D200 - """ - Print a concise summary of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def put_str(src, output_len=None, spaces=2): - src = str(src) - return src.ljust(output_len if output_len else len(src)) + " " * spaces - - def format_size(num): - for x in ["bytes", "KB", "MB", "GB", "TB"]: - if num < 1024.0: - return f"{num:3.1f} {x}" - num /= 1024.0 - return f"{num:3.1f} PB" - - output = [] - - type_line = str(type(self)) - index_line = "SnowflakeIndex" - columns = self.columns - columns_len = len(columns) - dtypes = self.dtypes - dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}" - - if max_cols is None: - max_cols = 100 - - exceeds_info_cols = columns_len > max_cols - - if buf is None: - buf = sys.stdout - - if null_counts is None: - null_counts = not exceeds_info_cols - - if verbose is None: - verbose = not exceeds_info_cols - - if null_counts and verbose: - # We're gonna take items from `non_null_count` in a loop, which - # works kinda slow with `Modin.Series`, that's why we call `_to_pandas()` here - # that will be faster. - non_null_count = self.count()._to_pandas() - - if memory_usage is None: - memory_usage = True - - def get_header(spaces=2): - output = [] - head_label = " # " - column_label = "Column" - null_label = "Non-Null Count" - dtype_label = "Dtype" - non_null_label = " non-null" - delimiter = "-" - - lengths = {} - lengths["head"] = max(len(head_label), len(pprint_thing(len(columns)))) - lengths["column"] = max( - len(column_label), max(len(pprint_thing(col)) for col in columns) - ) - lengths["dtype"] = len(dtype_label) - dtype_spaces = ( - max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes)) - - lengths["dtype"] - ) - - header = put_str(head_label, lengths["head"]) + put_str( - column_label, lengths["column"] - ) - if null_counts: - lengths["null"] = max( - len(null_label), - max(len(pprint_thing(x)) for x in non_null_count) - + len(non_null_label), - ) - header += put_str(null_label, lengths["null"]) - header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces) - - output.append(header) - - delimiters = put_str(delimiter * lengths["head"]) + put_str( - delimiter * lengths["column"] - ) - if null_counts: - delimiters += put_str(delimiter * lengths["null"]) - delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces) - output.append(delimiters) - - return output, lengths - - output.extend([type_line, index_line]) - - def verbose_repr(output): - columns_line = f"Data columns (total {len(columns)} columns):" - header, lengths = get_header() - output.extend([columns_line, *header]) - for i, col in enumerate(columns): - i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]]) - - to_append = put_str(f" {i}", lengths["head"]) + put_str( - col_s, lengths["column"] - ) - if null_counts: - non_null = pprint_thing(non_null_count[col]) - to_append += put_str(f"{non_null} non-null", lengths["null"]) - to_append += put_str(dtype, lengths["dtype"], spaces=0) - output.append(to_append) - - def non_verbose_repr(output): - output.append(columns._summary(name="Columns")) - - if verbose: - verbose_repr(output) - else: - non_verbose_repr(output) - - output.append(dtypes_line) - - if memory_usage: - deep = memory_usage == "deep" - mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum() - mem_line = f"memory usage: {format_size(mem_usage_bytes)}" - - output.append(mem_line) - - output.append("") - buf.write("\n".join(output)) - - def insert( - self, - loc: int, - column: Hashable, - value: Scalar | AnyArrayLike, - allow_duplicates: bool | NoDefault = no_default, - ) -> None: - """ - Insert column into ``DataFrame`` at specified location. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - raise_if_native_pandas_objects(value) - if allow_duplicates is no_default: - allow_duplicates = False - if not allow_duplicates and column in self.columns: - raise ValueError(f"cannot insert {column}, already exists") - - if not isinstance(loc, int): - raise TypeError("loc must be int") - - # If columns labels are multilevel, we implement following behavior (this is - # name native pandas): - # Case 1: if 'column' is tuple it's length must be same as number of levels - # otherwise raise error. - # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in - # empty strings to match the length of column levels in self frame. - if self.columns.nlevels > 1: - if isinstance(column, tuple) and len(column) != self.columns.nlevels: - # same error as native pandas. - raise ValueError("Item must have length equal to number of levels.") - if not isinstance(column, tuple): - # Fill empty strings to match length of levels - suffix = [""] * (self.columns.nlevels - 1) - column = tuple([column] + suffix) - - # Dictionary keys are treated as index column and this should be joined with - # index of target dataframe. This behavior is similar to 'value' being DataFrame - # or Series, so we simply create Series from dict data here. - if isinstance(value, dict): - value = Series(value, name=column) - - if isinstance(value, DataFrame) or ( - isinstance(value, np.ndarray) and len(value.shape) > 1 - ): - # Supported numpy array shapes are - # 1. (N, ) -> Ex. [1, 2, 3] - # 2. (N, 1) -> Ex> [[1], [2], [3]] - if value.shape[1] != 1: - if isinstance(value, DataFrame): - # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin - raise ValueError( - f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead." - ) - else: - raise ValueError( - f"Expected a 1D array, got an array with shape {value.shape}" - ) - # Change numpy array shape from (N, 1) to (N, ) - if isinstance(value, np.ndarray): - value = value.squeeze(axis=1) - - if ( - is_list_like(value) - and not isinstance(value, (Series, DataFrame)) - and len(value) != self.shape[0] - and not 0 == self.shape[0] # dataframe holds no rows - ): - raise ValueError( - "Length of values ({}) does not match length of index ({})".format( - len(value), len(self) - ) - ) - if not -len(self.columns) <= loc <= len(self.columns): - raise IndexError( - f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}" - ) - elif loc < 0: - raise ValueError("unbounded slice") - - join_on_index = False - if isinstance(value, (Series, DataFrame)): - value = value._query_compiler - join_on_index = True - elif is_list_like(value): - value = Series(value, name=column)._query_compiler - - new_query_compiler = self._query_compiler.insert( - loc, column, value, join_on_index - ) - # In pandas, 'insert' operation is always inplace. - self._update_inplace(new_query_compiler=new_query_compiler) - - @dataframe_not_implemented() - def interpolate( - self, - method="linear", - axis=0, - limit=None, - inplace=False, - limit_direction: str | None = None, - limit_area=None, - downcast=None, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Fill NaN values using an interpolation method. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.interpolate, - method=method, - axis=axis, - limit=limit, - inplace=inplace, - limit_direction=limit_direction, - limit_area=limit_area, - downcast=downcast, - **kwargs, - ) - - def iterrows(self) -> Iterator[tuple[Hashable, Series]]: - """ - Iterate over ``DataFrame`` rows as (index, ``Series``) pairs. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def iterrow_builder(s): - """Return tuple of the given `s` parameter name and the parameter themselves.""" - return s.name, s - - # Raise warning message since iterrows is very inefficient. - WarningMessage.single_warning( - DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows") - ) - - partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder) - yield from partition_iterator - - def items(self): # noqa: D200 - """ - Iterate over (column name, ``Series``) pairs. - """ - - def items_builder(s): - """Return tuple of the given `s` parameter name and the parameter themselves.""" - return s.name, s - - partition_iterator = PartitionIterator(self, 1, items_builder) - yield from partition_iterator - - def itertuples( - self, index: bool = True, name: str | None = "Pandas" - ) -> Iterable[tuple[Any, ...]]: - """ - Iterate over ``DataFrame`` rows as ``namedtuple``-s. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - def itertuples_builder(s): - """Return the next namedtuple.""" - # s is the Series of values in the current row. - fields = [] # column names - data = [] # values under each column - - if index: - data.append(s.name) - fields.append("Index") - - # Fill column names and values. - fields.extend(list(self.columns)) - data.extend(s) - - if name is not None: - # Creating the namedtuple. - itertuple = collections.namedtuple(name, fields, rename=True) - return itertuple._make(data) - - # When the name is None, return a regular tuple. - return tuple(data) - - # Raise warning message since itertuples is very inefficient. - WarningMessage.single_warning( - DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples") - ) - return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True) - - def join( - self, - other: DataFrame | Series | Iterable[DataFrame | Series], - on: IndexLabel | None = None, - how: str = "left", - lsuffix: str = "", - rsuffix: str = "", - sort: bool = False, - validate: str | None = None, - ) -> DataFrame: - """ - Join columns of another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - for o in other if isinstance(other, list) else [other]: - raise_if_native_pandas_objects(o) - - # Similar to native pandas we implement 'join' using 'pd.merge' method. - # Following code is copied from native pandas (with few changes explained below) - # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002 - if isinstance(other, Series): - # Same error as native pandas. - if other.name is None: - raise ValueError("Other Series must have a name") - other = DataFrame(other) - elif is_list_like(other): - if any([isinstance(o, Series) and o.name is None for o in other]): - raise ValueError("Other Series must have a name") - - if isinstance(other, DataFrame): - if how == "cross": - return pd.merge( - self, - other, - how=how, - on=on, - suffixes=(lsuffix, rsuffix), - sort=sort, - validate=validate, - ) - return pd.merge( - self, - other, - left_on=on, - how=how, - left_index=on is None, - right_index=True, - suffixes=(lsuffix, rsuffix), - sort=sort, - validate=validate, - ) - else: # List of DataFrame/Series - # Same error as native pandas. - if on is not None: - raise ValueError( - "Joining multiple DataFrames only supported for joining on index" - ) - - # Same error as native pandas. - if rsuffix or lsuffix: - raise ValueError( - "Suffixes not supported when joining multiple DataFrames" - ) - - # NOTE: These are not the differences between Snowpark pandas API and pandas behavior - # these are differences between native pandas join behavior when join - # frames have unique index or not. - - # In native pandas logic to join multiple DataFrames/Series is data - # dependent. Under the hood it will either use 'concat' or 'merge' API - # Case 1. If all objects being joined have unique index use 'concat' (axis=1) - # Case 2. Otherwise use 'merge' API by looping through objects left to right. - # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046 - - # Even though concat (axis=1) and merge are very similar APIs they have - # some differences which leads to inconsistent behavior in native pandas. - # 1. Treatment of un-named Series - # Case #1: Un-named series is allowed in concat API. Objects are joined - # successfully by assigning a number as columns name (see 'concat' API - # documentation for details on treatment of un-named series). - # Case #2: It raises 'ValueError: Other Series must have a name' - - # 2. how='right' - # Case #1: 'concat' API doesn't support right join. It raises - # 'ValueError: Only can inner (intersect) or outer (union) join the other axis' - # Case #2: Merges successfully. - - # 3. Joining frames with duplicate labels but no conflict with other frames - # Example: self = DataFrame(... columns=["A", "B"]) - # other = [DataFrame(... columns=["C", "C"])] - # Case #1: 'ValueError: Indexes have overlapping values' - # Case #2: Merged successfully. - - # In addition to this, native pandas implementation also leads to another - # type of inconsistency where left.join(other, ...) and - # left.join([other], ...) might behave differently for cases mentioned - # above. - # Example: - # import pandas as pd - # df = pd.DataFrame({"a": [4, 5]}) - # other = pd.Series([1, 2]) - # df.join([other]) # this is successful - # df.join(other) # this raises 'ValueError: Other Series must have a name' - - # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API - # to join multiple DataFrame/Series. So always follow the behavior - # documented as Case #2 above. - - joined = self - for frame in other: - if isinstance(frame, DataFrame): - overlapping_cols = set(joined.columns).intersection( - set(frame.columns) - ) - if len(overlapping_cols) > 0: - # Native pandas raises: 'Indexes have overlapping values' - # We differ slightly from native pandas message to make it more - # useful to users. - raise ValueError( - f"Join dataframes have overlapping column labels: {overlapping_cols}" - ) - joined = pd.merge( - joined, - frame, - how=how, - left_index=True, - right_index=True, - validate=validate, - sort=sort, - suffixes=(None, None), - ) - return joined - - def isna(self): - return super().isna() - - def isnull(self): - return super().isnull() - - @dataframe_not_implemented() - def isetitem(self, loc, value): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.isetitem, - loc=loc, - value=value, - ) - - def le(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get less than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("le", other, axis=axis, level=level) - - def lt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get less than comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("lt", other, axis=axis, level=level) - - def melt( - self, - id_vars=None, - value_vars=None, - var_name=None, - value_name="value", - col_level=None, - ignore_index=True, - ): # noqa: PR01, RT01, D200 - """ - Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if id_vars is None: - id_vars = [] - if not is_list_like(id_vars): - id_vars = [id_vars] - if value_vars is None: - # Behavior of Index.difference changed in 2.2.x - # https://github.com/pandas-dev/pandas/pull/55113 - # This change needs upstream to Modin: - # https://github.com/modin-project/modin/issues/7206 - value_vars = self.columns.drop(id_vars) - if var_name is None: - columns_name = self._query_compiler.get_index_name(axis=1) - var_name = columns_name if columns_name is not None else "variable" - return self.__constructor__( - query_compiler=self._query_compiler.melt( - id_vars=id_vars, - value_vars=value_vars, - var_name=var_name, - value_name=value_name, - col_level=col_level, - ignore_index=ignore_index, - ) - ) - - @dataframe_not_implemented() - def memory_usage(self, index=True, deep=False): # noqa: PR01, RT01, D200 - """ - Return the memory usage of each column in bytes. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - if index: - result = self._reduce_dimension( - self._query_compiler.memory_usage(index=False, deep=deep) - ) - index_value = self.index.memory_usage(deep=deep) - return pd.concat( - [Series(index_value, index=["Index"]), result] - ) # pragma: no cover - return super().memory_usage(index=index, deep=deep) - - def merge( - self, - right: DataFrame | Series, - how: str = "inner", - on: IndexLabel | None = None, - left_on: Hashable - | AnyArrayLike - | Sequence[Hashable | AnyArrayLike] - | None = None, - right_on: Hashable - | AnyArrayLike - | Sequence[Hashable | AnyArrayLike] - | None = None, - left_index: bool = False, - right_index: bool = False, - sort: bool = False, - suffixes: Suffixes = ("_x", "_y"), - copy: bool = True, - indicator: bool = False, - validate: str | None = None, - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # Raise error if native pandas objects are passed. - raise_if_native_pandas_objects(right) - - if isinstance(right, Series) and right.name is None: - raise ValueError("Cannot merge a Series without a name") - if not isinstance(right, (Series, DataFrame)): - raise TypeError( - f"Can only merge Series or DataFrame objects, a {type(right)} was passed" - ) - - if isinstance(right, Series): - right_column_nlevels = ( - len(right.name) if isinstance(right.name, tuple) else 1 - ) - else: - right_column_nlevels = right.columns.nlevels - if self.columns.nlevels != right_column_nlevels: - # This is deprecated in native pandas. We raise explicit error for this. - raise ValueError( - "Can not merge objects with different column levels." - + f" ({self.columns.nlevels} levels on the left," - + f" {right_column_nlevels} on the right)" - ) - - # Merge empty native pandas dataframes for error checking. Otherwise, it will - # require a lot of logic to be written. This takes care of raising errors for - # following scenarios: - # 1. Only 'left_index' is set to True. - # 2. Only 'right_index is set to True. - # 3. Only 'left_on' is provided. - # 4. Only 'right_on' is provided. - # 5. 'on' and 'left_on' both are provided - # 6. 'on' and 'right_on' both are provided - # 7. 'on' and 'left_index' both are provided - # 8. 'on' and 'right_index' both are provided - # 9. 'left_on' and 'left_index' both are provided - # 10. 'right_on' and 'right_index' both are provided - # 11. Length mismatch between 'left_on' and 'right_on' - # 12. 'left_index' is not a bool - # 13. 'right_index' is not a bool - # 14. 'on' is not None and how='cross' - # 15. 'left_on' is not None and how='cross' - # 16. 'right_on' is not None and how='cross' - # 17. 'left_index' is True and how='cross' - # 18. 'right_index' is True and how='cross' - # 19. Unknown label in 'on', 'left_on' or 'right_on' - # 20. Provided 'suffixes' is not sufficient to resolve conflicts. - # 21. Merging on column with duplicate labels. - # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'} - # 23. conflict with existing labels for array-like join key - # 24. 'indicator' argument is not bool or str - # 25. indicator column label conflicts with existing data labels - create_empty_native_pandas_frame(self).merge( - create_empty_native_pandas_frame(right), - on=on, - how=how, - left_on=replace_external_data_keys_with_empty_pandas_series(left_on), - right_on=replace_external_data_keys_with_empty_pandas_series(right_on), - left_index=left_index, - right_index=right_index, - suffixes=suffixes, - indicator=indicator, - ) - - return self.__constructor__( - query_compiler=self._query_compiler.merge( - right._query_compiler, - how=how, - on=on, - left_on=replace_external_data_keys_with_query_compiler(self, left_on), - right_on=replace_external_data_keys_with_query_compiler( - right, right_on - ), - left_index=left_index, - right_index=right_index, - sort=sort, - suffixes=suffixes, - copy=copy, - indicator=indicator, - validate=validate, - ) - ) - - def mod( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `mod`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "mod", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def mul( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "mul", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - multiply = mul - - def rmul( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rmul", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get not equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ne`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("ne", other, axis=axis, level=level) - - def nlargest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200 - """ - Return the first `n` rows ordered by `columns` in descending order. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.nlargest(n, columns, keep) - ) - - def nsmallest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200 - """ - Return the first `n` rows ordered by `columns` in ascending order. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.nsmallest( - n=n, columns=columns, keep=keep - ) - ) - - def unstack( - self, - level: int | str | list = -1, - fill_value: int | str | dict = None, - sort: bool = True, - ): - """ - Pivot a level of the (necessarily hierarchical) index labels. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # This ensures that non-pandas MultiIndex objects are caught. - nlevels = self._query_compiler.nlevels() - is_multiindex = nlevels > 1 - - if not is_multiindex or ( - is_multiindex and is_list_like(level) and len(level) == nlevels - ): - return self._reduce_dimension( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=False - ) - ) - else: - return self.__constructor__( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=False - ) - ) - - def pivot( - self, - *, - columns: Any, - index: Any | NoDefault = no_default, - values: Any | NoDefault = no_default, - ): - """ - Return reshaped DataFrame organized by given index / column values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if index is no_default: - index = None # pragma: no cover - if values is no_default: - values = None - - # if values is not specified, it should be the remaining columns not in - # index or columns - if values is None: - values = list(self.columns) - if index is not None: - values = [v for v in values if v not in index] - if columns is not None: - values = [v for v in values if v not in columns] - - return self.__constructor__( - query_compiler=self._query_compiler.pivot( - index=index, columns=columns, values=values - ) - ) - - def pivot_table( - self, - values=None, - index=None, - columns=None, - aggfunc="mean", - fill_value=None, - margins=False, - dropna=True, - margins_name="All", - observed=False, - sort=True, - ): - """ - Create a spreadsheet-style pivot table as a ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - result = self.__constructor__( - query_compiler=self._query_compiler.pivot_table( - index=index, - values=values, - columns=columns, - aggfunc=aggfunc, - fill_value=fill_value, - margins=margins, - dropna=dropna, - margins_name=margins_name, - observed=observed, - sort=sort, - ) - ) - return result - - @dataframe_not_implemented() - @property - def plot( - self, - x=None, - y=None, - kind="line", - ax=None, - subplots=False, - sharex=None, - sharey=False, - layout=None, - figsize=None, - use_index=True, - title=None, - grid=None, - legend=True, - style=None, - logx=False, - logy=False, - loglog=False, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - rot=None, - fontsize=None, - colormap=None, - table=False, - yerr=None, - xerr=None, - secondary_y=False, - sort_columns=False, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Make plots of ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._to_pandas().plot - - def pow( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "pow", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - @dataframe_not_implemented() - def prod( - self, - axis=None, - skipna=True, - numeric_only=False, - min_count=0, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Return the product of the values over the requested axis. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - validate_bool_kwarg(skipna, "skipna", none_allowed=False) - axis = self._get_axis_number(axis) - axis_to_apply = self.columns if axis else self.index - if ( - skipna is not False - and numeric_only is None - and min_count > len(axis_to_apply) - ): - new_index = self.columns if not axis else self.index - return Series( - [np.nan] * len(new_index), index=new_index, dtype=np.dtype("object") - ) - - data = self._validate_dtypes_sum_prod_mean(axis, numeric_only, ignore_axis=True) - if min_count > 1: - return data._reduce_dimension( - data._query_compiler.prod_min_count( - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - min_count=min_count, - **kwargs, - ) - ) - return data._reduce_dimension( - data._query_compiler.prod( - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - min_count=min_count, - **kwargs, - ) - ) - - product = prod - - def quantile( - self, - q: Scalar | ListLike = 0.5, - axis: Axis = 0, - numeric_only: bool = False, - interpolation: Literal[ - "linear", "lower", "higher", "midpoint", "nearest" - ] = "linear", - method: Literal["single", "table"] = "single", - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().quantile( - q=q, - axis=axis, - numeric_only=numeric_only, - interpolation=interpolation, - method=method, - ) - - @dataframe_not_implemented() - def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 - """ - Query the columns of a ``DataFrame`` with a boolean expression. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._update_var_dicts_in_kwargs(expr, kwargs) - self._validate_eval_query(expr, **kwargs) - inplace = validate_bool_kwarg(inplace, "inplace") - new_query_compiler = self._query_compiler.query(expr, **kwargs) - return self._create_or_update_from_compiler(new_query_compiler, inplace) - - def rename( - self, - mapper: Renamer | None = None, - *, - index: Renamer | None = None, - columns: Renamer | None = None, - axis: Axis | None = None, - copy: bool | None = None, - inplace: bool = False, - level: Level | None = None, - errors: IgnoreRaise = "ignore", - ) -> DataFrame | None: - """ - Alter axes labels. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - if mapper is None and index is None and columns is None: - raise TypeError("must pass an index to rename") - - if index is not None or columns is not None: - if axis is not None: - raise TypeError( - "Cannot specify both 'axis' and any of 'index' or 'columns'" - ) - elif mapper is not None: - raise TypeError( - "Cannot specify both 'mapper' and any of 'index' or 'columns'" - ) - else: - # use the mapper argument - if axis and self._get_axis_number(axis) == 1: - columns = mapper - else: - index = mapper - - if copy is not None: - WarningMessage.ignored_argument( - operation="dataframe.rename", - argument="copy", - message="copy parameter has been ignored with Snowflake execution engine", - ) - - if isinstance(index, dict): - index = Series(index) - - new_qc = self._query_compiler.rename( - index_renamer=index, columns_renamer=columns, level=level, errors=errors - ) - return self._create_or_update_from_compiler( - new_query_compiler=new_qc, inplace=inplace - ) - - def reindex( - self, - labels=None, - index=None, - columns=None, - axis=None, - method=None, - copy=None, - level=None, - fill_value=np.nan, - limit=None, - tolerance=None, - ): # noqa: PR01, RT01, D200 - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - axis = self._get_axis_number(axis) - if axis == 0 and labels is not None: - index = labels - elif labels is not None: - columns = labels - return super().reindex( - index=index, - columns=columns, - method=method, - copy=copy, - level=level, - fill_value=fill_value, - limit=limit, - tolerance=tolerance, - ) - - @dataframe_not_implemented() - def reindex_like( - self, - other, - method=None, - copy: bool | None = None, - limit=None, - tolerance=None, - ) -> DataFrame: # pragma: no cover - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if copy is None: - copy = True - # docs say "Same as calling .reindex(index=other.index, columns=other.columns,...).": - # https://pandas.pydata.org/pandas-docs/version/1.4/reference/api/pandas.DataFrame.reindex_like.html - return self.reindex( - index=other.index, - columns=other.columns, - method=method, - copy=copy, - limit=limit, - tolerance=tolerance, - ) - - def replace( - self, - to_replace=None, - value=no_default, - inplace: bool = False, - limit=None, - regex: bool = False, - method: str | NoDefault = no_default, - ): - """ - Replace values given in `to_replace` with `value`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - new_query_compiler = self._query_compiler.replace( - to_replace=to_replace, - value=value, - limit=limit, - regex=regex, - method=method, - ) - return self._create_or_update_from_compiler(new_query_compiler, inplace) - - def rfloordiv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `rfloordiv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rfloordiv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def radd( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get addition of ``DataFrame`` and `other`, element-wise (binary operator `radd`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "radd", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rmod( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `rmod`). - """ - return self._binary_op( - "rmod", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200 - return super().round(decimals, args=args, **kwargs) - - def rpow( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rpow", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rsub( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `rsub`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rsub", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rtruediv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `rtruediv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rtruediv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - rdiv = rtruediv - - def select_dtypes( - self, - include: ListLike | str | type | None = None, - exclude: ListLike | str | type | None = None, - ) -> DataFrame: - """ - Return a subset of the ``DataFrame``'s columns based on the column dtypes. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # This line defers argument validation to pandas, which will raise errors on our behalf in cases - # like if `include` and `exclude` are None, the same type is specified in both lists, or a string - # dtype (as opposed to object) is specified. - pandas.DataFrame().select_dtypes(include, exclude) - - if include and not is_list_like(include): - include = [include] - elif include is None: - include = [] - if exclude and not is_list_like(exclude): - exclude = [exclude] - elif exclude is None: - exclude = [] - - sel = tuple(map(set, (include, exclude))) - - # The width of the np.int_/float_ alias differs between Windows and other platforms, so - # we need to include a workaround. - # https://github.com/numpy/numpy/issues/9464 - # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036 - def check_sized_number_infer_dtypes(dtype): - if (isinstance(dtype, str) and dtype == "int") or (dtype is int): - return [np.int32, np.int64] - elif dtype == "float" or dtype is float: - return [np.float64, np.float32] - else: - return [infer_dtype_from_object(dtype)] - - include, exclude = map( - lambda x: set( - itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x)) - ), - sel, - ) - # We need to index on column position rather than label in case of duplicates - include_these = pandas.Series(not bool(include), index=range(len(self.columns))) - exclude_these = pandas.Series(not bool(exclude), index=range(len(self.columns))) - - def is_dtype_instance_mapper(dtype): - return functools.partial(issubclass, dtype.type) - - for i, dtype in enumerate(self.dtypes): - if include: - include_these[i] = any(map(is_dtype_instance_mapper(dtype), include)) - if exclude: - exclude_these[i] = not any( - map(is_dtype_instance_mapper(dtype), exclude) - ) - - dtype_indexer = include_these & exclude_these - indicate = [i for i, should_keep in dtype_indexer.items() if should_keep] - # We need to use iloc instead of drop in case of duplicate column names - return self.iloc[:, indicate] - - def shift( - self, - periods: int | Sequence[int] = 1, - freq=None, - axis: Axis = 0, - fill_value: Hashable = no_default, - suffix: str | None = None, - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().shift(periods, freq, axis, fill_value, suffix) - - def set_index( - self, - keys: IndexLabel - | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable], - drop: bool = True, - append: bool = False, - inplace: bool = False, - verify_integrity: bool = False, - ) -> None | DataFrame: - """ - Set the ``DataFrame`` index using existing columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - if not isinstance(keys, list): - keys = [keys] - - # make sure key is either hashable, index, or series - label_or_series = [] - - missing = [] - columns = self.columns.tolist() - for key in keys: - raise_if_native_pandas_objects(key) - if isinstance(key, pd.Series): - label_or_series.append(key._query_compiler) - elif isinstance(key, (np.ndarray, list, Iterator)): - label_or_series.append(pd.Series(key)._query_compiler) - elif isinstance(key, (pd.Index, pandas.MultiIndex)): - label_or_series += [ - s._query_compiler for s in self._to_series_list(key) - ] - else: - if not is_hashable(key): - raise TypeError( - f'The parameter "keys" may be a column key, one-dimensional array, or a list ' - f"containing only valid column keys and one-dimensional arrays. Received column " - f"of type {type(key)}" - ) - label_or_series.append(key) - found = key in columns - if columns.count(key) > 1: - raise ValueError(f"The column label '{key}' is not unique") - elif not found: - missing.append(key) - - if missing: - raise KeyError(f"None of {missing} are in the columns") - - new_query_compiler = self._query_compiler.set_index( - label_or_series, drop=drop, append=append - ) - - # TODO: SNOW-782633 improve this code once duplicate is supported - # this needs to pull all index which is inefficient - if verify_integrity and not new_query_compiler.index.is_unique: - duplicates = new_query_compiler.index[ - new_query_compiler.index.to_pandas().duplicated() - ].unique() - raise ValueError(f"Index has duplicate keys: {duplicates}") - - return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace) - - sparse = CachedAccessor("sparse", SparseFrameAccessor) - - def squeeze(self, axis: Axis | None = None): - """ - Squeeze 1 dimensional axis objects into scalars. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) if axis is not None else None - len_columns = self._query_compiler.get_axis_len(1) - if axis == 1 and len_columns == 1: - return Series(query_compiler=self._query_compiler) - if axis in [0, None]: - # get_axis_len(0) results in a sql query to count number of rows in current - # dataframe. We should only compute len_index if axis is 0 or None. - len_index = len(self) - if axis is None and (len_columns == 1 or len_index == 1): - return Series(query_compiler=self._query_compiler).squeeze() - if axis == 0 and len_index == 1: - return Series(query_compiler=self.T._query_compiler) - return self.copy() - - def stack( - self, - level: int | str | list = -1, - dropna: bool | NoDefault = no_default, - sort: bool | NoDefault = no_default, - future_stack: bool = False, # ignored - ): - """ - Stack the prescribed level(s) from columns to index. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if future_stack is not False: - WarningMessage.ignored_argument( # pragma: no cover - operation="DataFrame.stack", - argument="future_stack", - message="future_stack parameter has been ignored with Snowflake execution engine", - ) - if dropna is NoDefault: - dropna = True # pragma: no cover - if sort is NoDefault: - sort = True # pragma: no cover - - # This ensures that non-pandas MultiIndex objects are caught. - is_multiindex = len(self.columns.names) > 1 - if not is_multiindex or ( - is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels - ): - return self._reduce_dimension( - query_compiler=self._query_compiler.stack(level, dropna, sort) - ) - else: - return self.__constructor__( - query_compiler=self._query_compiler.stack(level, dropna, sort) - ) - - def sub( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `sub`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "sub", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - subtract = sub - - @dataframe_not_implemented() - def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Write a ``DataFrame`` to the binary Feather format. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas(pandas.DataFrame.to_feather, path, **kwargs) - - @dataframe_not_implemented() - def to_gbq( - self, - destination_table, - project_id=None, - chunksize=None, - reauth=False, - if_exists="fail", - auth_local_webserver=True, - table_schema=None, - location=None, - progress_bar=True, - credentials=None, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Write a ``DataFrame`` to a Google BigQuery table. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functionsf - return self._default_to_pandas( - pandas.DataFrame.to_gbq, - destination_table, - project_id=project_id, - chunksize=chunksize, - reauth=reauth, - if_exists=if_exists, - auth_local_webserver=auth_local_webserver, - table_schema=table_schema, - location=location, - progress_bar=progress_bar, - credentials=credentials, - ) - - @dataframe_not_implemented() - def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_orc, - path=path, - engine=engine, - index=index, - engine_kwargs=engine_kwargs, - ) - - @dataframe_not_implemented() - def to_html( - self, - buf=None, - columns=None, - col_space=None, - header=True, - index=True, - na_rep="NaN", - formatters=None, - float_format=None, - sparsify=None, - index_names=True, - justify=None, - max_rows=None, - max_cols=None, - show_dimensions=False, - decimal=".", - bold_rows=True, - classes=None, - escape=True, - notebook=False, - border=None, - table_id=None, - render_links=False, - encoding=None, - ): # noqa: PR01, RT01, D200 - """ - Render a ``DataFrame`` as an HTML table. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_html, - buf=buf, - columns=columns, - col_space=col_space, - header=header, - index=index, - na_rep=na_rep, - formatters=formatters, - float_format=float_format, - sparsify=sparsify, - index_names=index_names, - justify=justify, - max_rows=max_rows, - max_cols=max_cols, - show_dimensions=show_dimensions, - decimal=decimal, - bold_rows=bold_rows, - classes=classes, - escape=escape, - notebook=notebook, - border=border, - table_id=table_id, - render_links=render_links, - encoding=None, - ) - - @dataframe_not_implemented() - def to_parquet( - self, - path=None, - engine="auto", - compression="snappy", - index=None, - partition_cols=None, - storage_options: StorageOptions = None, - **kwargs, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - from snowflake.snowpark.modin.pandas.dispatching.factories.dispatcher import ( - FactoryDispatcher, - ) - - return FactoryDispatcher.to_parquet( - self._query_compiler, - path=path, - engine=engine, - compression=compression, - index=index, - partition_cols=partition_cols, - storage_options=storage_options, - **kwargs, - ) - - @dataframe_not_implemented() - def to_period( - self, freq=None, axis=0, copy=True - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Convert ``DataFrame`` from ``DatetimeIndex`` to ``PeriodIndex``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().to_period(freq=freq, axis=axis, copy=copy) - - @dataframe_not_implemented() - def to_records( - self, index=True, column_dtypes=None, index_dtypes=None - ): # noqa: PR01, RT01, D200 - """ - Convert ``DataFrame`` to a NumPy record array. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_records, - index=index, - column_dtypes=column_dtypes, - index_dtypes=index_dtypes, - ) - - @dataframe_not_implemented() - def to_stata( - self, - path: FilePath | WriteBuffer[bytes], - convert_dates: dict[Hashable, str] | None = None, - write_index: bool = True, - byteorder: str | None = None, - time_stamp: datetime.datetime | None = None, - data_label: str | None = None, - variable_labels: dict[Hashable, str] | None = None, - version: int | None = 114, - convert_strl: Sequence[Hashable] | None = None, - compression: CompressionOptions = "infer", - storage_options: StorageOptions = None, - *, - value_labels: dict[Hashable, dict[float | int, str]] | None = None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_stata, - path, - convert_dates=convert_dates, - write_index=write_index, - byteorder=byteorder, - time_stamp=time_stamp, - data_label=data_label, - variable_labels=variable_labels, - version=version, - convert_strl=convert_strl, - compression=compression, - storage_options=storage_options, - value_labels=value_labels, - ) - - @dataframe_not_implemented() - def to_xml( - self, - path_or_buffer=None, - index=True, - root_name="data", - row_name="row", - na_rep=None, - attr_cols=None, - elem_cols=None, - namespaces=None, - prefix=None, - encoding="utf-8", - xml_declaration=True, - pretty_print=True, - parser="lxml", - stylesheet=None, - compression="infer", - storage_options=None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.default_to_pandas( - pandas.DataFrame.to_xml, - path_or_buffer=path_or_buffer, - index=index, - root_name=root_name, - row_name=row_name, - na_rep=na_rep, - attr_cols=attr_cols, - elem_cols=elem_cols, - namespaces=namespaces, - prefix=prefix, - encoding=encoding, - xml_declaration=xml_declaration, - pretty_print=pretty_print, - parser=parser, - stylesheet=stylesheet, - compression=compression, - storage_options=storage_options, - ) - ) - - def to_dict( - self, - orient: Literal[ - "dict", "list", "series", "split", "tight", "records", "index" - ] = "dict", - into: type[dict] = dict, - ) -> dict | list[dict]: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._to_pandas().to_dict(orient=orient, into=into) - - def to_timestamp( - self, freq=None, how="start", axis=0, copy=True - ): # noqa: PR01, RT01, D200 - """ - Cast to DatetimeIndex of timestamps, at *beginning* of period. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().to_timestamp(freq=freq, how=how, axis=axis, copy=copy) - - def truediv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `truediv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "truediv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - div = divide = truediv - - def update( - self, other, join="left", overwrite=True, filter_func=None, errors="ignore" - ): # noqa: PR01, RT01, D200 - """ - Modify in place using non-NA values from another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not isinstance(other, DataFrame): - other = self.__constructor__(other) - query_compiler = self._query_compiler.df_update( - other._query_compiler, - join=join, - overwrite=overwrite, - filter_func=filter_func, - errors=errors, - ) - self._update_inplace(new_query_compiler=query_compiler) - - def diff( - self, - periods: int = 1, - axis: Axis = 0, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().diff( - periods=periods, - axis=axis, - ) - - def drop( - self, - labels: IndexLabel = None, - axis: Axis = 0, - index: IndexLabel = None, - columns: IndexLabel = None, - level: Level = None, - inplace: bool = False, - errors: IgnoreRaise = "raise", - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().drop( - labels=labels, - axis=axis, - index=index, - columns=columns, - level=level, - inplace=inplace, - errors=errors, - ) - - def value_counts( - self, - subset: Sequence[Hashable] | None = None, - normalize: bool = False, - sort: bool = True, - ascending: bool = False, - dropna: bool = True, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return Series( - query_compiler=self._query_compiler.value_counts( - subset=subset, - normalize=normalize, - sort=sort, - ascending=ascending, - dropna=dropna, - ), - name="proportion" if normalize else "count", - ) - - def mask( - self, - cond: DataFrame | Series | Callable | AnyArrayLike, - other: DataFrame | Series | Callable | Scalar | None = np.nan, - *, - inplace: bool = False, - axis: Axis | None = None, - level: Level | None = None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, Series) and axis is None: - raise ValueError( - "df.mask requires an axis parameter (0 or 1) when given a Series" - ) - - return super().mask( - cond, - other=other, - inplace=inplace, - axis=axis, - level=level, - ) - - def where( - self, - cond: DataFrame | Series | Callable | AnyArrayLike, - other: DataFrame | Series | Callable | Scalar | None = np.nan, - *, - inplace: bool = False, - axis: Axis | None = None, - level: Level | None = None, - ): - """ - Replace values where the condition is False. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, Series) and axis is None: - raise ValueError( - "df.where requires an axis parameter (0 or 1) when given a Series" - ) - - return super().where( - cond, - other=other, - inplace=inplace, - axis=axis, - level=level, - ) - - @dataframe_not_implemented() - def xs(self, key, axis=0, level=None, drop_level=True): # noqa: PR01, RT01, D200 - """ - Return cross-section from the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.xs, key, axis=axis, level=level, drop_level=drop_level - ) - - def set_axis( - self, - labels: IndexLabel, - *, - axis: Axis = 0, - copy: bool | NoDefault = no_default, # ignored - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not is_scalar(axis): - raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") - return super().set_axis( - labels=labels, - # 'columns', 'rows, 'index, 0, and 1 are the only valid axis values for df. - axis=pandas.DataFrame._get_axis_name(axis), - copy=copy, - ) - - def __getattr__(self, key): - """ - Return item identified by `key`. - - Parameters - ---------- - key : hashable - Key to get. - - Returns - ------- - Any - - Notes - ----- - First try to use `__getattribute__` method. If it fails - try to get `key` from ``DataFrame`` fields. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - try: - return object.__getattribute__(self, key) - except AttributeError as err: - if key not in _ATTRS_NO_LOOKUP and key in self.columns: - return self[key] - raise err - - def __setattr__(self, key, value): - """ - Set attribute `value` identified by `key`. - - Parameters - ---------- - key : hashable - Key to set. - value : Any - Value to set. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # While we let users assign to a column labeled "x" with "df.x" , there - # are some attributes that we should assume are NOT column names and - # therefore should follow the default Python object assignment - # behavior. These are: - # - anything in self.__dict__. This includes any attributes that the - # user has added to the dataframe with, e.g., `df.c = 3`, and - # any attribute that Modin has added to the frame, e.g. - # `_query_compiler` and `_siblings` - # - `_query_compiler`, which Modin initializes before it appears in - # __dict__ - # - `_siblings`, which Modin initializes before it appears in __dict__ - # - `_cache`, which pandas.cache_readonly uses to cache properties - # before it appears in __dict__. - if key in ("_query_compiler", "_siblings", "_cache") or key in self.__dict__: - pass - elif key in self and key not in dir(self): - self.__setitem__(key, value) - # Note: return immediately so we don't keep this `key` as dataframe state. - # `__getattr__` will return the columns not present in `dir(self)`, so we do not need - # to manually track this state in the `dir`. - return - elif is_list_like(value) and key not in ["index", "columns"]: - WarningMessage.single_warning( - SET_DATAFRAME_ATTRIBUTE_WARNING - ) # pragma: no cover - object.__setattr__(self, key, value) - - def __setitem__(self, key: Any, value: Any): - """ - Set attribute `value` identified by `key`. - - Args: - key: Key to set - value: Value to set - - Note: - In the case where value is any list like or array, pandas checks the array length against the number of rows - of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw - a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if - the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use - enlargement filling with the last value in the array. - - Returns: - None - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - key = apply_if_callable(key, self) - if isinstance(key, DataFrame) or ( - isinstance(key, np.ndarray) and len(key.shape) == 2 - ): - # This case uses mask's codepath to perform the set, but - # we need to duplicate the code here since we are passing - # an additional kwarg `cond_fillna_with_true` to the QC here. - # We need this additional kwarg, since if df.shape - # and key.shape do not align (i.e. df has more rows), - # mask's codepath would mask the additional rows in df - # while for setitem, we need to keep the original values. - if not isinstance(key, DataFrame): - if key.dtype != bool: - raise TypeError( - "Must pass DataFrame or 2-d ndarray with boolean values only" - ) - key = DataFrame(key) - key._query_compiler._shape_hint = "array" - - if value is not None: - value = apply_if_callable(value, self) - - if isinstance(value, np.ndarray): - value = DataFrame(value) - value._query_compiler._shape_hint = "array" - elif isinstance(value, pd.Series): - # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this - # error instead, since it is more descriptive. - raise ValueError( - "setitem with a 2D key does not support Series values." - ) - - if isinstance(value, BasePandasDataset): - value = value._query_compiler - - query_compiler = self._query_compiler.mask( - cond=key._query_compiler, - other=value, - axis=None, - level=None, - cond_fillna_with_true=True, - ) - - return self._create_or_update_from_compiler(query_compiler, inplace=True) - - # Error Checking: - if (isinstance(key, pd.Series) or is_list_like(key)) and ( - isinstance(value, range) - ): - raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) - elif isinstance(value, slice): - # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. - raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) - - # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column - # key. - index, columns = slice(None), key - index_is_bool_indexer = False - if isinstance(key, slice): - if is_integer(key.start) and is_integer(key.stop): - # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as - # df.iloc[1:2, :] = val - self.iloc[key] = value - return - index, columns = key, slice(None) - elif isinstance(key, pd.Series): - if is_bool_dtype(key.dtype): - index, columns = key, slice(None) - index_is_bool_indexer = True - elif is_bool_indexer(key): - index, columns = pd.Series(key), slice(None) - index_is_bool_indexer = True - - # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case - # we have to explicitly set matching_item_columns_by_label to False for setitem. - index = index._query_compiler if isinstance(index, BasePandasDataset) else index - columns = ( - columns._query_compiler - if isinstance(columns, BasePandasDataset) - else columns - ) - from .indexing import is_2d_array - - matching_item_rows_by_label = not is_2d_array(value) - if is_2d_array(value): - value = DataFrame(value) - item = value._query_compiler if isinstance(value, BasePandasDataset) else value - new_qc = self._query_compiler.set_2d_labels( - index, - columns, - item, - # setitem always matches item by position - matching_item_columns_by_label=False, - matching_item_rows_by_label=matching_item_rows_by_label, - index_is_bool_indexer=index_is_bool_indexer, - # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling - # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the - # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have - # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns - # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", - # "X", "X". - deduplicate_columns=True, - ) - return self._update_inplace(new_query_compiler=new_qc) - - def abs(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().abs() - - def __and__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__and__", other, axis=1) - - def __rand__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__rand__", other, axis=1) - - def __or__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__or__", other, axis=1) - - def __ror__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__ror__", other, axis=1) - - def __neg__(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().__neg__() - - def __iter__(self): - """ - Iterate over info axis. - - Returns - ------- - iterable - Iterator of the columns names. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return iter(self.columns) - - def __contains__(self, key): - """ - Check if `key` in the ``DataFrame.columns``. - - Parameters - ---------- - key : hashable - Key to check the presence in the columns. - - Returns - ------- - bool - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.columns.__contains__(key) - - def __round__(self, decimals=0): - """ - Round each value in a ``DataFrame`` to the given number of decimals. - - Parameters - ---------- - decimals : int, default: 0 - Number of decimal places to round to. - - Returns - ------- - DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().round(decimals) - - @dataframe_not_implemented() - def __delitem__(self, key): - """ - Delete item identified by `key` label. - - Parameters - ---------- - key : hashable - Key to delete. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if key not in self: - raise KeyError(key) - self._update_inplace(new_query_compiler=self._query_compiler.delitem(key)) - - __add__ = add - __iadd__ = add # pragma: no cover - __radd__ = radd - __mul__ = mul - __imul__ = mul # pragma: no cover - __rmul__ = rmul - __pow__ = pow - __ipow__ = pow # pragma: no cover - __rpow__ = rpow - __sub__ = sub - __isub__ = sub # pragma: no cover - __rsub__ = rsub - __floordiv__ = floordiv - __ifloordiv__ = floordiv # pragma: no cover - __rfloordiv__ = rfloordiv - __truediv__ = truediv - __itruediv__ = truediv # pragma: no cover - __rtruediv__ = rtruediv - __mod__ = mod - __imod__ = mod # pragma: no cover - __rmod__ = rmod - __rdiv__ = rdiv - - def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): - """ - Get a Modin DataFrame that implements the dataframe exchange protocol. - - See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. - - Parameters - ---------- - nan_as_null : bool, default: False - A keyword intended for the consumer to tell the producer - to overwrite null values in the data with ``NaN`` (or ``NaT``). - This currently has no effect; once support for nullable extension - dtypes is added, this value should be propagated to columns. - allow_copy : bool, default: True - A keyword that defines whether or not the library is allowed - to make a copy of the data. For example, copying data would be necessary - if a library supports strided buffers, given that this protocol - specifies contiguous buffers. Currently, if the flag is set to ``False`` - and a copy is needed, a ``RuntimeError`` will be raised. - - Returns - ------- - ProtocolDataframe - A dataframe object following the dataframe protocol specification. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - ErrorMessage.not_implemented( - "Snowpark pandas does not support the DataFrame interchange " - + "protocol method `__dataframe__`. To use Snowpark pandas " - + "DataFrames with third-party libraries that try to call the " - + "`__dataframe__` method, please convert this Snowpark pandas " - + "DataFrame to pandas with `to_pandas()`." - ) - - return self._query_compiler.to_dataframe( - nan_as_null=nan_as_null, allow_copy=allow_copy - ) - - @dataframe_not_implemented() - @property - def attrs(self): # noqa: RT01, D200 - """ - Return dictionary of global attributes of this dataset. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def attrs(df): - return df.attrs - - return self._default_to_pandas(attrs) - - @dataframe_not_implemented() - @property - def style(self): # noqa: RT01, D200 - """ - Return a Styler object. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def style(df): - """Define __name__ attr because properties do not have it.""" - return df.style - - return self._default_to_pandas(style) - - def isin( - self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike] - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(values, dict): - return super().isin(values) - elif isinstance(values, Series): - # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. - # if not values.index.is_unique: - # raise ValueError("cannot compute isin with a duplicate axis.") - return self.__constructor__( - query_compiler=self._query_compiler.isin(values._query_compiler) - ) - elif isinstance(values, DataFrame): - # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. - # if not (values.columns.is_unique and values.index.is_unique): - # raise ValueError("cannot compute isin with a duplicate axis.") - return self.__constructor__( - query_compiler=self._query_compiler.isin(values._query_compiler) - ) - else: - if not is_list_like(values): - # throw pandas compatible error - raise TypeError( - "only list-like or dict-like objects are allowed " - f"to be passed to {self.__class__.__name__}.isin(), " - f"you passed a '{type(values).__name__}'" - ) - return super().isin(values) - - def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): - """ - Return or update a ``DataFrame`` with given `new_query_compiler`. - - Parameters - ---------- - new_query_compiler : PandasQueryCompiler - QueryCompiler to use to manage the data. - inplace : bool, default: False - Whether or not to perform update or creation inplace. - - Returns - ------- - DataFrame or None - None if update was done, ``DataFrame`` otherwise. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - assert ( - isinstance(new_query_compiler, type(self._query_compiler)) - or type(new_query_compiler) in self._query_compiler.__class__.__bases__ - ), f"Invalid Query Compiler object: {type(new_query_compiler)}" - if not inplace: - return self.__constructor__(query_compiler=new_query_compiler) - else: - self._update_inplace(new_query_compiler=new_query_compiler) - - def _get_numeric_data(self, axis: int): - """ - Grab only numeric data from ``DataFrame``. - - Parameters - ---------- - axis : {0, 1} - Axis to inspect on having numeric types only. - - Returns - ------- - DataFrame - ``DataFrame`` with numeric data. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas ignores `numeric_only` if `axis` is 1, but we do have to drop - # non-numeric columns if `axis` is 0. - if axis != 0: - return self - return self.drop( - columns=[ - i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) - ] - ) - - def _validate_dtypes(self, numeric_only=False): - """ - Check that all the dtypes are the same. - - Parameters - ---------- - numeric_only : bool, default: False - Whether or not to allow only numeric data. - If True and non-numeric data is found, exception - will be raised. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - dtype = self.dtypes[0] - for t in self.dtypes: - if numeric_only and not is_numeric_dtype(t): - raise TypeError(f"{t} is not a numeric data type") - elif not numeric_only and t != dtype: - raise TypeError(f"Cannot compare type '{t}' with type '{dtype}'") - - def _validate_dtypes_sum_prod_mean(self, axis, numeric_only, ignore_axis=False): - """ - Validate data dtype for `sum`, `prod` and `mean` methods. - - Parameters - ---------- - axis : {0, 1} - Axis to validate over. - numeric_only : bool - Whether or not to allow only numeric data. - If True and non-numeric data is found, exception - will be raised. - ignore_axis : bool, default: False - Whether or not to ignore `axis` parameter. - - Returns - ------- - DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # We cannot add datetime types, so if we are summing a column with - # dtype datetime64 and cannot ignore non-numeric types, we must throw a - # TypeError. - if ( - not axis - and numeric_only is False - and any(dtype == np.dtype("datetime64[ns]") for dtype in self.dtypes) - ): - raise TypeError("Cannot add Timestamp Types") - - # If our DataFrame has both numeric and non-numeric dtypes then - # operations between these types do not make sense and we must raise a - # TypeError. The exception to this rule is when there are datetime and - # timedelta objects, in which case we proceed with the comparison - # without ignoring any non-numeric types. We must check explicitly if - # numeric_only is False because if it is None, it will default to True - # if the operation fails with mixed dtypes. - if ( - (axis or ignore_axis) - and numeric_only is False - and np.unique([is_numeric_dtype(dtype) for dtype in self.dtypes]).size == 2 - ): - # check if there are columns with dtypes datetime or timedelta - if all( - dtype != np.dtype("datetime64[ns]") - and dtype != np.dtype("timedelta64[ns]") - for dtype in self.dtypes - ): - raise TypeError("Cannot operate on Numeric and Non-Numeric Types") - - return self._get_numeric_data(axis) if numeric_only else self - - def _to_pandas( - self, - *, - statement_params: dict[str, str] | None = None, - **kwargs: Any, - ) -> pandas.DataFrame: - """ - Convert Snowpark pandas DataFrame to pandas DataFrame - - Args: - statement_params: Dictionary of statement level parameters to be set while executing this action. - - Returns: - pandas DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.to_pandas( - statement_params=statement_params, **kwargs - ) - - def _validate_eval_query(self, expr, **kwargs): - """ - Validate the arguments of ``eval`` and ``query`` functions. - - Parameters - ---------- - expr : str - The expression to evaluate. This string cannot contain any - Python statements, only Python expressions. - **kwargs : dict - Optional arguments of ``eval`` and ``query`` functions. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(expr, str) and expr == "": - raise ValueError("expr cannot be an empty string") - - if isinstance(expr, str) and "not" in expr: - if "parser" in kwargs and kwargs["parser"] == "python": - ErrorMessage.not_implemented( # pragma: no cover - "Snowpark pandas does not yet support 'not' in the " - + "expression for the methods `DataFrame.eval` or " - + "`DataFrame.query`" - ) - - def _reduce_dimension(self, query_compiler): - """ - Reduce the dimension of data from the `query_compiler`. - - Parameters - ---------- - query_compiler : BaseQueryCompiler - Query compiler to retrieve the data. - - Returns - ------- - Series - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return Series(query_compiler=query_compiler) - - def _set_axis_name(self, name, axis=0, inplace=False): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) - renamed = self if inplace else self.copy() - if axis == 0: - renamed.index = renamed.index.set_names(name) - else: - renamed.columns = renamed.columns.set_names(name) - if not inplace: - return renamed - - def _to_datetime(self, **kwargs): - """ - Convert `self` to datetime. - - Parameters - ---------- - **kwargs : dict - Optional arguments to use during query compiler's - `to_datetime` invocation. - - Returns - ------- - Series of datetime64 dtype - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._reduce_dimension( - query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) - ) - - # Persistance support methods - BEGIN - @classmethod - def _inflate_light(cls, query_compiler): - """ - Re-creates the object from previously-serialized lightweight representation. - - The method is used for faster but not disk-storable persistence. - - Parameters - ---------- - query_compiler : BaseQueryCompiler - Query compiler to use for object re-creation. - - Returns - ------- - DataFrame - New ``DataFrame`` based on the `query_compiler`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return cls(query_compiler=query_compiler) - - @classmethod - def _inflate_full(cls, pandas_df): - """ - Re-creates the object from previously-serialized disk-storable representation. - - Parameters - ---------- - pandas_df : pandas.DataFrame - Data to use for object re-creation. - - Returns - ------- - DataFrame - New ``DataFrame`` based on the `pandas_df`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return cls(data=from_pandas(pandas_df)) - - @dataframe_not_implemented() - def __reduce__(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._query_compiler.finalize() - # if PersistentPickle.get(): - # return self._inflate_full, (self._to_pandas(),) - return self._inflate_light, (self._query_compiler,) - - # Persistance support methods - END diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py index 2ca9d8e5b83..5024d0618ac 100644 --- a/src/snowflake/snowpark/modin/pandas/general.py +++ b/src/snowflake/snowpark/modin/pandas/general.py @@ -31,7 +31,7 @@ import numpy as np import pandas import pandas.core.common as common -from modin.pandas import Series +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas import IntervalIndex, NaT, Timedelta, Timestamp from pandas._libs import NaTType, lib @@ -65,7 +65,6 @@ # add this line to make doctests runnable from snowflake.snowpark.modin import pandas as pd # noqa: F401 -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import ( is_scalar, raise_if_native_pandas_objects, @@ -92,10 +91,9 @@ # linking to `snowflake.snowpark.DataFrame`, we need to explicitly # qualify return types in this file with `modin.pandas.DataFrame`. # SNOW-1233342: investigate how to fix these links without using absolute paths + import modin from modin.core.storage_formats import BaseQueryCompiler # pragma: no cover - import snowflake # pragma: no cover - _logger = getLogger(__name__) VALID_DATE_TYPE = Union[ @@ -137,8 +135,8 @@ def notna(obj): # noqa: PR01, RT01, D200 @snowpark_pandas_telemetry_standalone_function_decorator def merge( - left: snowflake.snowpark.modin.pandas.DataFrame | Series, - right: snowflake.snowpark.modin.pandas.DataFrame | Series, + left: modin.pandas.DataFrame | Series, + right: modin.pandas.DataFrame | Series, how: str | None = "inner", on: IndexLabel | None = None, left_on: None @@ -414,7 +412,7 @@ def merge_asof( tolerance: int | Timedelta | None = None, allow_exact_matches: bool = True, direction: str = "backward", -) -> snowflake.snowpark.modin.pandas.DataFrame: +) -> modin.pandas.DataFrame: """ Perform a merge by key distance. @@ -1105,8 +1103,8 @@ def value_counts( @snowpark_pandas_telemetry_standalone_function_decorator def concat( objs: ( - Iterable[snowflake.snowpark.modin.pandas.DataFrame | Series] - | Mapping[Hashable, snowflake.snowpark.modin.pandas.DataFrame | Series] + Iterable[modin.pandas.DataFrame | Series] + | Mapping[Hashable, modin.pandas.DataFrame | Series] ), axis: Axis = 0, join: str = "outer", @@ -1117,7 +1115,7 @@ def concat( verify_integrity: bool = False, sort: bool = False, copy: bool = True, -) -> snowflake.snowpark.modin.pandas.DataFrame | Series: +) -> modin.pandas.DataFrame | Series: """ Concatenate pandas objects along a particular axis. @@ -1490,7 +1488,7 @@ def concat( def to_datetime( arg: DatetimeScalarOrArrayConvertible | DictConvertible - | snowflake.snowpark.modin.pandas.DataFrame + | modin.pandas.DataFrame | Series, errors: DateTimeErrorChoices = "raise", dayfirst: bool = False, diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py index c672f04da63..5da10d9b7a6 100644 --- a/src/snowflake/snowpark/modin/pandas/indexing.py +++ b/src/snowflake/snowpark/modin/pandas/indexing.py @@ -45,6 +45,7 @@ import pandas from modin.pandas import Series from modin.pandas.base import BasePandasDataset +from modin.pandas.dataframe import DataFrame from pandas._libs.tslibs import Resolution, parsing from pandas._typing import AnyArrayLike, Scalar from pandas.api.types import is_bool, is_list_like @@ -61,7 +62,6 @@ import snowflake.snowpark.modin.pandas as pd import snowflake.snowpark.modin.pandas.utils as frontend_utils -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import is_scalar from snowflake.snowpark.modin.plugin._internal.indexing_utils import ( MULTIPLE_ELLIPSIS_INDEXING_ERROR_MESSAGE, diff --git a/src/snowflake/snowpark/modin/pandas/io.py b/src/snowflake/snowpark/modin/pandas/io.py index 25959212a18..b92e8ee3582 100644 --- a/src/snowflake/snowpark/modin/pandas/io.py +++ b/src/snowflake/snowpark/modin/pandas/io.py @@ -92,7 +92,7 @@ # below logic is to handle circular imports without errors if TYPE_CHECKING: # pragma: no cover - from .dataframe import DataFrame + from modin.pandas.dataframe import DataFrame # TODO: SNOW-1265551: add inherit_docstrings decorators once docstring overrides are available @@ -106,7 +106,7 @@ class ModinObjects: def DataFrame(cls): """Get ``modin.pandas.DataFrame`` class.""" if cls._dataframe is None: - from .dataframe import DataFrame + from modin.pandas.dataframe import DataFrame cls._dataframe = DataFrame return cls._dataframe diff --git a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py index 3529355b81b..ee782f3cdf3 100644 --- a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py +++ b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py @@ -5,10 +5,9 @@ from collections.abc import Iterator from typing import Any, Callable +import modin.pandas.dataframe as DataFrame import pandas -import snowflake.snowpark.modin.pandas.dataframe as DataFrame - PARTITION_SIZE = 4096 diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py index 3986e3d52a9..a48f16992d4 100644 --- a/src/snowflake/snowpark/modin/pandas/utils.py +++ b/src/snowflake/snowpark/modin/pandas/utils.py @@ -78,7 +78,7 @@ def from_non_pandas(df, index, columns, dtype): new_qc = FactoryDispatcher.from_non_pandas(df, index, columns, dtype) if new_qc is not None: - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame return DataFrame(query_compiler=new_qc) return new_qc @@ -99,7 +99,7 @@ def from_pandas(df): A new Modin DataFrame object. """ # from modin.core.execution.dispatching.factories.dispatcher import FactoryDispatcher - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_pandas(df)) @@ -118,10 +118,11 @@ def from_arrow(at): DataFrame A new Modin DataFrame object. """ + from modin.pandas import DataFrame + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( FactoryDispatcher, ) - from snowflake.snowpark.modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_arrow(at)) @@ -142,10 +143,11 @@ def from_dataframe(df): DataFrame A new Modin DataFrame object. """ + from modin.pandas import DataFrame + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( FactoryDispatcher, ) - from snowflake.snowpark.modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_dataframe(df)) @@ -226,7 +228,7 @@ def from_modin_frame_to_mi(df, sortorder=None, names=None): pandas.MultiIndex The pandas.MultiIndex representation of the given DataFrame. """ - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame if isinstance(df, DataFrame): df = df._to_pandas() diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py index d3ac525572a..eceb9ca7d7f 100644 --- a/src/snowflake/snowpark/modin/plugin/__init__.py +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -69,6 +69,7 @@ inherit_modules = [ (docstrings.base.BasePandasDataset, modin.pandas.base.BasePandasDataset), + (docstrings.dataframe.DataFrame, modin.pandas.dataframe.DataFrame), (docstrings.series.Series, modin.pandas.series.Series), (docstrings.series_utils.StringMethods, modin.pandas.series_utils.StringMethods), ( @@ -90,17 +91,3 @@ snowflake.snowpark._internal.utils.should_warn_dynamic_pivot_is_in_private_preview = ( False ) - - -# TODO: SNOW-1504302: Modin upgrade - use Snowpark pandas DataFrame for isocalendar -# OSS Modin's DatetimeProperties frontend class wraps the returned query compiler with `modin.pandas.DataFrame`. -# Since we currently replace `pd.DataFrame` with our own Snowpark pandas DataFrame object, this causes errors -# since OSS Modin explicitly imports its own DataFrame class here. This override can be removed once the frontend -# DataFrame class is removed from our codebase. -def isocalendar(self): # type: ignore - from snowflake.snowpark.modin.pandas import DataFrame - - return DataFrame(query_compiler=self._query_compiler.dt_isocalendar()) - - -modin.pandas.series_utils.DatetimeProperties.isocalendar = isocalendar diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py index d38584c14de..e19a6de37ba 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py @@ -567,9 +567,7 @@ def __new__( attrs (Dict[str, Any]): The attributes of the class. Returns: - Union[snowflake.snowpark.modin.pandas.series.Series, - snowflake.snowpark.modin.pandas.dataframe.DataFrame, - snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy, + Union[snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy, snowflake.snowpark.modin.pandas.resample.Resampler, snowflake.snowpark.modin.pandas.window.Window, snowflake.snowpark.modin.pandas.window.Rolling]: diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py index 9f01954ab2c..70025fd8b0a 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py @@ -1276,7 +1276,7 @@ def check_snowpark_pandas_object_in_arg(arg: Any) -> bool: if check_snowpark_pandas_object_in_arg(v): return True else: - from snowflake.snowpark.modin.pandas import DataFrame, Series + from modin.pandas import DataFrame, Series return isinstance(arg, (DataFrame, Series)) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 48f91ab40dd..e971b15b6d6 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -2037,8 +2037,8 @@ def binary_op( # Native pandas does not support binary operations between a Series and a list-like object. from modin.pandas import Series + from modin.pandas.dataframe import DataFrame - from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import is_scalar # fail explicitly for unsupported scenarios diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/base.py b/src/snowflake/snowpark/modin/plugin/docstrings/base.py index af50e0379dd..4044f7b675f 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/base.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/base.py @@ -2832,6 +2832,7 @@ def shift(): """ Implement shared functionality between DataFrame and Series for shift. axis argument is only relevant for Dataframe, and should be 0 for Series. + Args: periods : int | Sequence[int] Number of periods to shift. Can be positive or negative. If an iterable of ints, diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py index 6223e9dd273..f7e93e6c2df 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py @@ -1749,7 +1749,7 @@ def info(): ... 'COL2': ['A', 'B', 'C']}) >>> df.info() # doctest: +NORMALIZE_WHITESPACE - + SnowflakeIndex Data columns (total 2 columns): # Column Non-Null Count Dtype diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index aeca9d6e305..ecef6e843ba 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -60,7 +60,6 @@ validate_percentile, ) -import snowflake.snowpark.modin.pandas as spd from snowflake.snowpark.modin.pandas.api.extensions import ( register_dataframe_accessor, register_series_accessor, @@ -88,8 +87,6 @@ def register_base_override(method_name: str): for directly overriding methods on BasePandasDataset, we mock this by performing the override on DataFrame and Series, and manually performing a `setattr` on the base class. These steps are necessary to allow both the docstring extension and method dispatch to work properly. - - Methods annotated here also are automatically instrumented with Snowpark pandas telemetry. """ def decorator(base_method: Any): @@ -103,10 +100,7 @@ def decorator(base_method: Any): series_method = series_method.fget if series_method is None or series_method is parent_method: register_series_accessor(method_name)(base_method) - # TODO: SNOW-1063346 - # Since we still use the vendored version of DataFrame and the overrides for the top-level - # namespace haven't been performed yet, we need to set properties on the vendored version - df_method = getattr(spd.dataframe.DataFrame, method_name, None) + df_method = getattr(pd.DataFrame, method_name, None) if isinstance(df_method, property): df_method = df_method.fget if df_method is None or df_method is parent_method: @@ -176,6 +170,22 @@ def filter( pass # pragma: no cover +@register_base_not_implemented() +def interpolate( + self, + method="linear", + *, + axis=0, + limit=None, + inplace=False, + limit_direction: str | None = None, + limit_area=None, + downcast=lib.no_default, + **kwargs, +): # noqa: PR01, RT01, D200 + pass + + @register_base_not_implemented() def pipe(self, func, *args, **kwargs): # noqa: PR01, RT01, D200 pass # pragma: no cover @@ -813,7 +823,7 @@ def _binary_op( **kwargs, ) - from snowflake.snowpark.modin.pandas.dataframe import DataFrame + from modin.pandas.dataframe import DataFrame # Modin Bug: https://github.com/modin-project/modin/issues/7236 # For a Series interacting with a DataFrame, always return a DataFrame diff --git a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py index 5ce836061ab..62c9cab4dc1 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py @@ -7,20 +7,1443 @@ pandas, such as `DataFrame.memory_usage`. """ -from typing import Any, Union +from __future__ import annotations +import collections +import datetime +import functools +import itertools +import sys +import warnings +from typing import ( + IO, + Any, + Callable, + Hashable, + Iterable, + Iterator, + Literal, + Mapping, + Sequence, +) + +import modin.pandas as pd +import numpy as np import pandas as native_pd -from modin.pandas import DataFrame -from pandas._typing import Axis, PythonFuncType -from pandas.core.dtypes.common import is_dict_like, is_list_like +from modin.pandas import DataFrame, Series +from modin.pandas.base import BasePandasDataset +from pandas._libs.lib import NoDefault, no_default +from pandas._typing import ( + AggFuncType, + AnyArrayLike, + Axes, + Axis, + CompressionOptions, + FilePath, + FillnaOptions, + IgnoreRaise, + IndexLabel, + Level, + PythonFuncType, + Renamer, + Scalar, + StorageOptions, + Suffixes, + WriteBuffer, +) +from pandas.core.common import apply_if_callable, is_bool_indexer +from pandas.core.dtypes.common import ( + infer_dtype_from_object, + is_bool_dtype, + is_dict_like, + is_list_like, + is_numeric_dtype, +) +from pandas.core.dtypes.inference import is_hashable, is_integer +from pandas.core.indexes.frozen import FrozenList +from pandas.io.formats.printing import pprint_thing +from pandas.util._validators import validate_bool_kwarg + +from snowflake.snowpark.modin.pandas.api.extensions import register_dataframe_accessor +from snowflake.snowpark.modin.pandas.groupby import ( + DataFrameGroupBy, + validate_groupby_args, +) +from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( + SnowparkPandasRowPartitionIterator, +) +from snowflake.snowpark.modin.pandas.utils import ( + create_empty_native_pandas_frame, + from_non_pandas, + from_pandas, + is_scalar, + raise_if_native_pandas_objects, + replace_external_data_keys_with_empty_pandas_series, + replace_external_data_keys_with_query_compiler, +) +from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + is_snowflake_agg_func, +) +from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated +from snowflake.snowpark.modin.plugin._typing import ListLike +from snowflake.snowpark.modin.plugin.utils.error_message import ( + ErrorMessage, + dataframe_not_implemented, +) +from snowflake.snowpark.modin.plugin.utils.frontend_constants import ( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE, + DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE, + DF_SETITEM_SLICE_AS_SCALAR_VALUE, +) +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import ( + _inherit_docstrings, + hashable, + validate_int_kwarg, +) +from snowflake.snowpark.udf import UserDefinedFunction + + +def register_dataframe_not_implemented(): + def decorator(base_method: Any): + func = dataframe_not_implemented()(base_method) + register_dataframe_accessor(base_method.__name__)(func) + return func + + return decorator + + +# === UNIMPLEMENTED METHODS === +# The following methods are not implemented in Snowpark pandas, and must be overridden on the +# frontend. These methods fall into a few categories: +# 1. Would work in Snowpark pandas, but we have not tested it. +# 2. Would work in Snowpark pandas, but requires more SQL queries than we are comfortable with. +# 3. Requires materialization (usually via a frontend _default_to_pandas call). +# 4. Performs operations on a native pandas Index object that are nontrivial for Snowpark pandas to manage. + + +# Avoid overwriting builtin `map` by accident +@register_dataframe_accessor("map") +@dataframe_not_implemented() +def _map(self, func, na_action: str | None = None, **kwargs) -> DataFrame: + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def boxplot( + self, + column=None, + by=None, + ax=None, + fontsize=None, + rot=0, + grid=True, + figsize=None, + layout=None, + return_type=None, + backend=None, + **kwargs, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def combine( + self, other, func, fill_value=None, overwrite=True +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def corrwith( + self, other, axis=0, drop=False, method="pearson", numeric_only=False +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def cov( + self, min_periods=None, ddof: int | None = 1, numeric_only=False +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def dot(self, other): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def hist( + self, + column=None, + by=None, + grid=True, + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, + ax=None, + sharex=False, + sharey=False, + figsize=None, + layout=None, + bins=10, + **kwds, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def isetitem(self, loc, value): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def prod( + self, + axis=None, + skipna=True, + numeric_only=False, + min_count=0, + **kwargs, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +register_dataframe_accessor("product")(prod) + + +@register_dataframe_not_implemented() +def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def reindex_like( + self, + other, + method=None, + copy: bool | None = None, + limit=None, + tolerance=None, +) -> DataFrame: # pragma: no cover + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_gbq( + self, + destination_table, + project_id=None, + chunksize=None, + reauth=False, + if_exists="fail", + auth_local_webserver=True, + table_schema=None, + location=None, + progress_bar=True, + credentials=None, +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_html( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep="NaN", + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + justify=None, + max_rows=None, + max_cols=None, + show_dimensions=False, + decimal=".", + bold_rows=True, + classes=None, + escape=True, + notebook=False, + border=None, + table_id=None, + render_links=False, + encoding=None, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_parquet( + self, + path=None, + engine="auto", + compression="snappy", + index=None, + partition_cols=None, + storage_options: StorageOptions = None, + **kwargs, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_period( + self, freq=None, axis=0, copy=True +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_records( + self, index=True, column_dtypes=None, index_dtypes=None +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_stata( + self, + path: FilePath | WriteBuffer[bytes], + convert_dates: dict[Hashable, str] | None = None, + write_index: bool = True, + byteorder: str | None = None, + time_stamp: datetime.datetime | None = None, + data_label: str | None = None, + variable_labels: dict[Hashable, str] | None = None, + version: int | None = 114, + convert_strl: Sequence[Hashable] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions = None, + *, + value_labels: dict[Hashable, dict[float | int, str]] | None = None, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_xml( + self, + path_or_buffer=None, + index=True, + root_name="data", + row_name="row", + na_rep=None, + attr_cols=None, + elem_cols=None, + namespaces=None, + prefix=None, + encoding="utf-8", + xml_declaration=True, + pretty_print=True, + parser="lxml", + stylesheet=None, + compression="infer", + storage_options=None, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __delitem__(self, key): + pass # pragma: no cover + + +@register_dataframe_accessor("attrs") +@dataframe_not_implemented() +@property +def attrs(self): # noqa: RT01, D200 + pass # pragma: no cover + + +@register_dataframe_accessor("style") +@dataframe_not_implemented() +@property +def style(self): # noqa: RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __reduce__(self): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __divmod__(self, other): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __rdivmod__(self, other): + pass # pragma: no cover + + +# The from_dict and from_records accessors are class methods and cannot be overridden via the +# extensions module, as they need to be foisted onto the namespace directly because they are not +# routed through getattr. To this end, we manually set DataFrame.from_dict to our new method. +@dataframe_not_implemented() +def from_dict( + cls, data, orient="columns", dtype=None, columns=None +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +DataFrame.from_dict = from_dict + + +@dataframe_not_implemented() +def from_records( + cls, + data, + index=None, + exclude=None, + columns=None, + coerce_float=False, + nrows=None, +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +DataFrame.from_records = from_records + + +# === OVERRIDDEN METHODS === +# The below methods have their frontend implementations overridden compared to the version present +# in series.py. This is usually for one of the following reasons: +# 1. The underlying QC interface used differs from that of modin. Notably, this applies to aggregate +# and binary operations; further work is needed to refactor either our implementation or upstream +# modin's implementation. +# 2. Modin performs extra validation queries that perform extra SQL queries. Some of these are already +# fixed on main; see https://github.com/modin-project/modin/issues/7340 for details. +# 3. Upstream Modin defaults to pandas for some edge cases. Defaulting to pandas at the query compiler +# layer is acceptable because we can force the method to raise NotImplementedError, but if a method +# defaults at the frontend, Modin raises a warning and performs the operation by coercing the +# dataset to a native pandas object. Removing these is tracked by +# https://github.com/modin-project/modin/issues/7104 + + +# Snowpark pandas overrides the constructor for two reasons: +# 1. To support the Snowpark pandas lazy index object +# 2. To avoid raising "UserWarning: Distributing object. This may take some time." +# when a literal is passed in as data. +@register_dataframe_accessor("__init__") +def __init__( + self, + data=None, + index=None, + columns=None, + dtype=None, + copy=None, + query_compiler=None, +) -> None: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # Siblings are other dataframes that share the same query compiler. We + # use this list to update inplace when there is a shallow copy. + from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native + + self._siblings = [] + + # Engine.subscribe(_update_engine) + if isinstance(data, (DataFrame, Series)): + self._query_compiler = data._query_compiler.copy() + if index is not None and any(i not in data.index for i in index): + ErrorMessage.not_implemented( + "Passing non-existant columns or index values to constructor not" + + " yet implemented." + ) # pragma: no cover + if isinstance(data, Series): + # We set the column name if it is not in the provided Series + if data.name is None: + self.columns = [0] if columns is None else columns + # If the columns provided are not in the named Series, pandas clears + # the DataFrame and sets columns to the columns provided. + elif columns is not None and data.name not in columns: + self._query_compiler = from_pandas( + self.__constructor__(columns=columns) + )._query_compiler + if index is not None: + self._query_compiler = data.loc[index]._query_compiler + elif columns is None and index is None: + data._add_sibling(self) + else: + if columns is not None and any(i not in data.columns for i in columns): + ErrorMessage.not_implemented( + "Passing non-existant columns or index values to constructor not" + + " yet implemented." + ) # pragma: no cover + if index is None: + index = slice(None) + if columns is None: + columns = slice(None) + self._query_compiler = data.loc[index, columns]._query_compiler + + # Check type of data and use appropriate constructor + elif query_compiler is None: + distributed_frame = from_non_pandas(data, index, columns, dtype) + if distributed_frame is not None: + self._query_compiler = distributed_frame._query_compiler + return + + if isinstance(data, native_pd.Index): + pass + elif is_list_like(data) and not is_dict_like(data): + old_dtype = getattr(data, "dtype", None) + values = [ + obj._to_pandas() if isinstance(obj, Series) else obj for obj in data + ] + if isinstance(data, np.ndarray): + data = np.array(values, dtype=old_dtype) + else: + try: + data = type(data)(values, dtype=old_dtype) + except TypeError: + data = values + elif is_dict_like(data) and not isinstance( + data, (native_pd.Series, Series, native_pd.DataFrame, DataFrame) + ): + if columns is not None: + data = {key: value for key, value in data.items() if key in columns} + + if len(data) and all(isinstance(v, Series) for v in data.values()): + from modin.pandas import concat + + new_qc = concat(data.values(), axis=1, keys=data.keys())._query_compiler + + if dtype is not None: + new_qc = new_qc.astype({col: dtype for col in new_qc.columns}) + if index is not None: + new_qc = new_qc.reindex( + axis=0, labels=try_convert_index_to_native(index) + ) + if columns is not None: + new_qc = new_qc.reindex( + axis=1, labels=try_convert_index_to_native(columns) + ) + + self._query_compiler = new_qc + return + + data = { + k: v._to_pandas() if isinstance(v, Series) else v + for k, v in data.items() + } + pandas_df = native_pd.DataFrame( + data=try_convert_index_to_native(data), + index=try_convert_index_to_native(index), + columns=try_convert_index_to_native(columns), + dtype=dtype, + copy=copy, + ) + self._query_compiler = from_pandas(pandas_df)._query_compiler + else: + self._query_compiler = query_compiler + + +@register_dataframe_accessor("__dataframe__") +def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Get a Modin DataFrame that implements the dataframe exchange protocol. + + See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. + + Parameters + ---------- + nan_as_null : bool, default: False + A keyword intended for the consumer to tell the producer + to overwrite null values in the data with ``NaN`` (or ``NaT``). + This currently has no effect; once support for nullable extension + dtypes is added, this value should be propagated to columns. + allow_copy : bool, default: True + A keyword that defines whether or not the library is allowed + to make a copy of the data. For example, copying data would be necessary + if a library supports strided buffers, given that this protocol + specifies contiguous buffers. Currently, if the flag is set to ``False`` + and a copy is needed, a ``RuntimeError`` will be raised. + + Returns + ------- + ProtocolDataframe + A dataframe object following the dataframe protocol specification. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented( + "Snowpark pandas does not support the DataFrame interchange " + + "protocol method `__dataframe__`. To use Snowpark pandas " + + "DataFrames with third-party libraries that try to call the " + + "`__dataframe__` method, please convert this Snowpark pandas " + + "DataFrame to pandas with `to_pandas()`." + ) + + return self._query_compiler.to_dataframe( + nan_as_null=nan_as_null, allow_copy=allow_copy + ) + + +# Snowpark pandas defaults to axis=1 instead of axis=0 for these; we need to investigate if the same should +# apply to upstream Modin. +@register_dataframe_accessor("__and__") +def __and__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__and__", other, axis=1) + + +@register_dataframe_accessor("__rand__") +def __rand__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__rand__", other, axis=1) + + +@register_dataframe_accessor("__or__") +def __or__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__or__", other, axis=1) + + +@register_dataframe_accessor("__ror__") +def __ror__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__ror__", other, axis=1) + + +# Upstream Modin defaults to pandas in some cases. +@register_dataframe_accessor("apply") +def apply( + self, + func: AggFuncType | UserDefinedFunction, + axis: Axis = 0, + raw: bool = False, + result_type: Literal["expand", "reduce", "broadcast"] | None = None, + args=(), + **kwargs, +): + """ + Apply a function along an axis of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) + query_compiler = self._query_compiler.apply( + func, + axis, + raw=raw, + result_type=result_type, + args=args, + **kwargs, + ) + if not isinstance(query_compiler, type(self._query_compiler)): + # A scalar was returned + return query_compiler + + # If True, it is an unamed series. + # Theoretically, if df.apply returns a Series, it will only be an unnamed series + # because the function is supposed to be series -> scalar. + if query_compiler._modin_frame.is_unnamed_series(): + return Series(query_compiler=query_compiler) + else: + return self.__constructor__(query_compiler=query_compiler) + + +# Snowpark pandas uses a separate QC method, while modin directly calls map. +@register_dataframe_accessor("applymap") +def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not callable(func): + raise TypeError(f"{func} is not callable") + return self.__constructor__( + query_compiler=self._query_compiler.applymap( + func, na_action=na_action, **kwargs + ) + ) + + +# We need to override _get_columns to satisfy +# tests/unit/modin/test_type_annotations.py::test_properties_snow_1374293[_get_columns-type_hints1] +# since Modin doesn't provide this type hint. +def _get_columns(self) -> native_pd.Index: + """ + Get the columns for this Snowpark pandas ``DataFrame``. + + Returns + ------- + Index + The all columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.columns + + +# Snowpark pandas wraps this in an update_in_place +def _set_columns(self, new_columns: Axes) -> None: + """ + Set the columns for this Snowpark pandas ``DataFrame``. + + Parameters + ---------- + new_columns : + The new columns to set. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + self._update_inplace( + new_query_compiler=self._query_compiler.set_columns(new_columns) + ) + + +register_dataframe_accessor("columns")(property(_get_columns, _set_columns)) + + +# Snowpark pandas does preprocessing for numeric_only (should be pushed to QC). +@register_dataframe_accessor("corr") +def corr( + self, + method: str | Callable = "pearson", + min_periods: int | None = None, + numeric_only: bool = False, +): # noqa: PR01, RT01, D200 + """ + Compute pairwise correlation of columns, excluding NA/null values. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + corr_df = self + if numeric_only: + corr_df = self.drop( + columns=[ + i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) + ] + ) + return self.__constructor__( + query_compiler=corr_df._query_compiler.corr( + method=method, + min_periods=min_periods, + ) + ) + + +# Snowpark pandas does not respect `ignore_index`, and upstream Modin does not respect `how`. +@register_dataframe_accessor("dropna") +def dropna( + self, + *, + axis: Axis = 0, + how: str | NoDefault = no_default, + thresh: int | NoDefault = no_default, + subset: IndexLabel = None, + inplace: bool = False, +): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super(DataFrame, self)._dropna( + axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace + ) + + +# Snowpark pandas uses `self_is_series`, while upstream Modin uses `squeeze_self` and `squeeze_value`. +@register_dataframe_accessor("fillna") +def fillna( + self, + value: Hashable | Mapping | Series | DataFrame = None, + *, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, +) -> DataFrame | None: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super(DataFrame, self).fillna( + self_is_series=False, + value=value, + method=method, + axis=axis, + inplace=inplace, + limit=limit, + downcast=downcast, + ) + + +# Snowpark pandas does different validation and returns a custom GroupBy object. +@register_dataframe_accessor("groupby") +def groupby( + self, + by=None, + axis: Axis | NoDefault = no_default, + level: IndexLabel | None = None, + as_index: bool = True, + sort: bool = True, + group_keys: bool = True, + observed: bool | NoDefault = no_default, + dropna: bool = True, +): + """ + Group ``DataFrame`` using a mapper or by a ``Series`` of columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if axis is not no_default: + axis = self._get_axis_number(axis) + if axis == 1: + warnings.warn( + "DataFrame.groupby with axis=1 is deprecated. Do " + + "`frame.T.groupby(...)` without axis instead.", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + "The 'axis' keyword in DataFrame.groupby is deprecated and " + + "will be removed in a future version.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + validate_groupby_args(by, level, observed) + + axis = self._get_axis_number(axis) + + if axis != 0 and as_index is False: + raise ValueError("as_index=False only valid for axis=0") + + idx_name = None + + if ( + not isinstance(by, Series) + and is_list_like(by) + and len(by) == 1 + # if by is a list-like of (None,), we have to keep it as a list because + # None may be referencing a column or index level whose label is + # `None`, and by=None wold mean that there is no `by` param. + and by[0] is not None + ): + by = by[0] + + if hashable(by) and ( + not callable(by) and not isinstance(by, (native_pd.Grouper, FrozenList)) + ): + idx_name = by + elif isinstance(by, Series): + idx_name = by.name + if by._parent is self: + # if the SnowSeries comes from the current dataframe, + # convert it to labels directly for easy processing + by = by.name + elif is_list_like(by): + if axis == 0 and all( + ( + (hashable(o) and (o in self)) + or isinstance(o, Series) + or (is_list_like(o) and len(o) == len(self.shape[axis])) + ) + for o in by + ): + # plit 'by's into those that belongs to the self (internal_by) + # and those that doesn't (external_by). For SnowSeries that belongs + # to current DataFrame, we convert it to labels for easy process. + internal_by, external_by = [], [] + + for current_by in by: + if hashable(current_by): + internal_by.append(current_by) + elif isinstance(current_by, Series): + if current_by._parent is self: + internal_by.append(current_by.name) + else: + external_by.append(current_by) # pragma: no cover + else: + external_by.append(current_by) + + by = internal_by + external_by + + return DataFrameGroupBy( + self, + by, + axis, + level, + as_index, + sort, + group_keys, + idx_name, + observed=observed, + dropna=dropna, + ) + + +# Upstream Modin uses a proxy DataFrameInfo object +@register_dataframe_accessor("info") +def info( + self, + verbose: bool | None = None, + buf: IO[str] | None = None, + max_cols: int | None = None, + memory_usage: bool | str | None = None, + show_counts: bool | None = None, + null_counts: bool | None = None, +): # noqa: PR01, D200 + """ + Print a concise summary of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + def put_str(src, output_len=None, spaces=2): + src = str(src) + return src.ljust(output_len if output_len else len(src)) + " " * spaces + + def format_size(num): + for x in ["bytes", "KB", "MB", "GB", "TB"]: + if num < 1024.0: + return f"{num:3.1f} {x}" + num /= 1024.0 + return f"{num:3.1f} PB" + + output = [] + + type_line = str(type(self)) + index_line = "SnowflakeIndex" + columns = self.columns + columns_len = len(columns) + dtypes = self.dtypes + dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}" + + if max_cols is None: + max_cols = 100 + + exceeds_info_cols = columns_len > max_cols + + if buf is None: + buf = sys.stdout + + if null_counts is None: + null_counts = not exceeds_info_cols + + if verbose is None: + verbose = not exceeds_info_cols + + if null_counts and verbose: + # We're gonna take items from `non_null_count` in a loop, which + # works kinda slow with `Modin.Series`, that's why we call `_to_pandas()` here + # that will be faster. + non_null_count = self.count()._to_pandas() + + if memory_usage is None: + memory_usage = True + + def get_header(spaces=2): + output = [] + head_label = " # " + column_label = "Column" + null_label = "Non-Null Count" + dtype_label = "Dtype" + non_null_label = " non-null" + delimiter = "-" + + lengths = {} + lengths["head"] = max(len(head_label), len(pprint_thing(len(columns)))) + lengths["column"] = max( + len(column_label), max(len(pprint_thing(col)) for col in columns) + ) + lengths["dtype"] = len(dtype_label) + dtype_spaces = ( + max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes)) + - lengths["dtype"] + ) + + header = put_str(head_label, lengths["head"]) + put_str( + column_label, lengths["column"] + ) + if null_counts: + lengths["null"] = max( + len(null_label), + max(len(pprint_thing(x)) for x in non_null_count) + len(non_null_label), + ) + header += put_str(null_label, lengths["null"]) + header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces) + + output.append(header) + + delimiters = put_str(delimiter * lengths["head"]) + put_str( + delimiter * lengths["column"] + ) + if null_counts: + delimiters += put_str(delimiter * lengths["null"]) + delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces) + output.append(delimiters) + + return output, lengths + + output.extend([type_line, index_line]) + + def verbose_repr(output): + columns_line = f"Data columns (total {len(columns)} columns):" + header, lengths = get_header() + output.extend([columns_line, *header]) + for i, col in enumerate(columns): + i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]]) + + to_append = put_str(f" {i}", lengths["head"]) + put_str( + col_s, lengths["column"] + ) + if null_counts: + non_null = pprint_thing(non_null_count[col]) + to_append += put_str(f"{non_null} non-null", lengths["null"]) + to_append += put_str(dtype, lengths["dtype"], spaces=0) + output.append(to_append) + + def non_verbose_repr(output): + output.append(columns._summary(name="Columns")) + + if verbose: + verbose_repr(output) + else: + non_verbose_repr(output) + + output.append(dtypes_line) + + if memory_usage: + deep = memory_usage == "deep" + mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum() + mem_line = f"memory usage: {format_size(mem_usage_bytes)}" + + output.append(mem_line) + + output.append("") + buf.write("\n".join(output)) + + +# Snowpark pandas does different validation. +@register_dataframe_accessor("insert") +def insert( + self, + loc: int, + column: Hashable, + value: Scalar | AnyArrayLike, + allow_duplicates: bool | NoDefault = no_default, +) -> None: + """ + Insert column into ``DataFrame`` at specified location. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + raise_if_native_pandas_objects(value) + if allow_duplicates is no_default: + allow_duplicates = False + if not allow_duplicates and column in self.columns: + raise ValueError(f"cannot insert {column}, already exists") + + if not isinstance(loc, int): + raise TypeError("loc must be int") + + # If columns labels are multilevel, we implement following behavior (this is + # name native pandas): + # Case 1: if 'column' is tuple it's length must be same as number of levels + # otherwise raise error. + # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in + # empty strings to match the length of column levels in self frame. + if self.columns.nlevels > 1: + if isinstance(column, tuple) and len(column) != self.columns.nlevels: + # same error as native pandas. + raise ValueError("Item must have length equal to number of levels.") + if not isinstance(column, tuple): + # Fill empty strings to match length of levels + suffix = [""] * (self.columns.nlevels - 1) + column = tuple([column] + suffix) + + # Dictionary keys are treated as index column and this should be joined with + # index of target dataframe. This behavior is similar to 'value' being DataFrame + # or Series, so we simply create Series from dict data here. + if isinstance(value, dict): + value = Series(value, name=column) + + if isinstance(value, DataFrame) or ( + isinstance(value, np.ndarray) and len(value.shape) > 1 + ): + # Supported numpy array shapes are + # 1. (N, ) -> Ex. [1, 2, 3] + # 2. (N, 1) -> Ex> [[1], [2], [3]] + if value.shape[1] != 1: + if isinstance(value, DataFrame): + # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin + raise ValueError( + f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead." + ) + else: + raise ValueError( + f"Expected a 1D array, got an array with shape {value.shape}" + ) + # Change numpy array shape from (N, 1) to (N, ) + if isinstance(value, np.ndarray): + value = value.squeeze(axis=1) + + if ( + is_list_like(value) + and not isinstance(value, (Series, DataFrame)) + and len(value) != self.shape[0] + and not 0 == self.shape[0] # dataframe holds no rows + ): + raise ValueError( + "Length of values ({}) does not match length of index ({})".format( + len(value), len(self) + ) + ) + if not -len(self.columns) <= loc <= len(self.columns): + raise IndexError( + f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}" + ) + elif loc < 0: + raise ValueError("unbounded slice") + + join_on_index = False + if isinstance(value, (Series, DataFrame)): + value = value._query_compiler + join_on_index = True + elif is_list_like(value): + value = Series(value, name=column)._query_compiler + + new_query_compiler = self._query_compiler.insert(loc, column, value, join_on_index) + # In pandas, 'insert' operation is always inplace. + self._update_inplace(new_query_compiler=new_query_compiler) + + +# Snowpark pandas does more specialization based on the type of `values` +@register_dataframe_accessor("isin") +def isin( + self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike] +) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(values, dict): + return super(DataFrame, self).isin(values) + elif isinstance(values, Series): + # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. + # if not values.index.is_unique: + # raise ValueError("cannot compute isin with a duplicate axis.") + return self.__constructor__( + query_compiler=self._query_compiler.isin(values._query_compiler) + ) + elif isinstance(values, DataFrame): + # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. + # if not (values.columns.is_unique and values.index.is_unique): + # raise ValueError("cannot compute isin with a duplicate axis.") + return self.__constructor__( + query_compiler=self._query_compiler.isin(values._query_compiler) + ) + else: + if not is_list_like(values): + # throw pandas compatible error + raise TypeError( + "only list-like or dict-like objects are allowed " + f"to be passed to {self.__class__.__name__}.isin(), " + f"you passed a '{type(values).__name__}'" + ) + return super(DataFrame, self).isin(values) + + +# Upstream Modin defaults to pandas for some arguments. +@register_dataframe_accessor("join") +def join( + self, + other: DataFrame | Series | Iterable[DataFrame | Series], + on: IndexLabel | None = None, + how: str = "left", + lsuffix: str = "", + rsuffix: str = "", + sort: bool = False, + validate: str | None = None, +) -> DataFrame: + """ + Join columns of another ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + for o in other if isinstance(other, list) else [other]: + raise_if_native_pandas_objects(o) + + # Similar to native pandas we implement 'join' using 'pd.merge' method. + # Following code is copied from native pandas (with few changes explained below) + # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002 + if isinstance(other, Series): + # Same error as native pandas. + if other.name is None: + raise ValueError("Other Series must have a name") + other = DataFrame(other) + elif is_list_like(other): + if any([isinstance(o, Series) and o.name is None for o in other]): + raise ValueError("Other Series must have a name") + + if isinstance(other, DataFrame): + if how == "cross": + return pd.merge( + self, + other, + how=how, + on=on, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + return pd.merge( + self, + other, + left_on=on, + how=how, + left_index=on is None, + right_index=True, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + else: # List of DataFrame/Series + # Same error as native pandas. + if on is not None: + raise ValueError( + "Joining multiple DataFrames only supported for joining on index" + ) + + # Same error as native pandas. + if rsuffix or lsuffix: + raise ValueError("Suffixes not supported when joining multiple DataFrames") + + # NOTE: These are not the differences between Snowpark pandas API and pandas behavior + # these are differences between native pandas join behavior when join + # frames have unique index or not. + + # In native pandas logic to join multiple DataFrames/Series is data + # dependent. Under the hood it will either use 'concat' or 'merge' API + # Case 1. If all objects being joined have unique index use 'concat' (axis=1) + # Case 2. Otherwise use 'merge' API by looping through objects left to right. + # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046 + + # Even though concat (axis=1) and merge are very similar APIs they have + # some differences which leads to inconsistent behavior in native pandas. + # 1. Treatment of un-named Series + # Case #1: Un-named series is allowed in concat API. Objects are joined + # successfully by assigning a number as columns name (see 'concat' API + # documentation for details on treatment of un-named series). + # Case #2: It raises 'ValueError: Other Series must have a name' + + # 2. how='right' + # Case #1: 'concat' API doesn't support right join. It raises + # 'ValueError: Only can inner (intersect) or outer (union) join the other axis' + # Case #2: Merges successfully. + + # 3. Joining frames with duplicate labels but no conflict with other frames + # Example: self = DataFrame(... columns=["A", "B"]) + # other = [DataFrame(... columns=["C", "C"])] + # Case #1: 'ValueError: Indexes have overlapping values' + # Case #2: Merged successfully. -from snowflake.snowpark.modin.pandas.api.extensions import register_dataframe_accessor -from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( - is_snowflake_agg_func, -) -from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage -from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage -from snowflake.snowpark.modin.utils import _inherit_docstrings, validate_int_kwarg + # In addition to this, native pandas implementation also leads to another + # type of inconsistency where left.join(other, ...) and + # left.join([other], ...) might behave differently for cases mentioned + # above. + # Example: + # import pandas as pd + # df = pd.DataFrame({"a": [4, 5]}) + # other = pd.Series([1, 2]) + # df.join([other]) # this is successful + # df.join(other) # this raises 'ValueError: Other Series must have a name' + + # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API + # to join multiple DataFrame/Series. So always follow the behavior + # documented as Case #2 above. + + joined = self + for frame in other: + if isinstance(frame, DataFrame): + overlapping_cols = set(joined.columns).intersection(set(frame.columns)) + if len(overlapping_cols) > 0: + # Native pandas raises: 'Indexes have overlapping values' + # We differ slightly from native pandas message to make it more + # useful to users. + raise ValueError( + f"Join dataframes have overlapping column labels: {overlapping_cols}" + ) + joined = pd.merge( + joined, + frame, + how=how, + left_index=True, + right_index=True, + validate=validate, + sort=sort, + suffixes=(None, None), + ) + return joined + + +# Snowpark pandas does extra error checking. +@register_dataframe_accessor("mask") +def mask( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(other, Series) and axis is None: + raise ValueError( + "df.mask requires an axis parameter (0 or 1) when given a Series" + ) + + return super(DataFrame, self).mask( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + +# Snowpark pandas has a fix for a pandas behavior change. It is available in Modin 0.30.1 (SNOW-1552497). +@register_dataframe_accessor("melt") +def melt( + self, + id_vars=None, + value_vars=None, + var_name=None, + value_name="value", + col_level=None, + ignore_index=True, +): # noqa: PR01, RT01, D200 + """ + Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if id_vars is None: + id_vars = [] + if not is_list_like(id_vars): + id_vars = [id_vars] + if value_vars is None: + # Behavior of Index.difference changed in 2.2.x + # https://github.com/pandas-dev/pandas/pull/55113 + # This change needs upstream to Modin: + # https://github.com/modin-project/modin/issues/7206 + value_vars = self.columns.drop(id_vars) + if var_name is None: + columns_name = self._query_compiler.get_index_name(axis=1) + var_name = columns_name if columns_name is not None else "variable" + return self.__constructor__( + query_compiler=self._query_compiler.melt( + id_vars=id_vars, + value_vars=value_vars, + var_name=var_name, + value_name=value_name, + col_level=col_level, + ignore_index=ignore_index, + ) + ) + + +# Snowpark pandas does more thorough error checking. +@register_dataframe_accessor("merge") +def merge( + self, + right: DataFrame | Series, + how: str = "inner", + on: IndexLabel | None = None, + left_on: Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike] | None = None, + right_on: Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike] | None = None, + left_index: bool = False, + right_index: bool = False, + sort: bool = False, + suffixes: Suffixes = ("_x", "_y"), + copy: bool = True, + indicator: bool = False, + validate: str | None = None, +) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # Raise error if native pandas objects are passed. + raise_if_native_pandas_objects(right) + + if isinstance(right, Series) and right.name is None: + raise ValueError("Cannot merge a Series without a name") + if not isinstance(right, (Series, DataFrame)): + raise TypeError( + f"Can only merge Series or DataFrame objects, a {type(right)} was passed" + ) + + if isinstance(right, Series): + right_column_nlevels = len(right.name) if isinstance(right.name, tuple) else 1 + else: + right_column_nlevels = right.columns.nlevels + if self.columns.nlevels != right_column_nlevels: + # This is deprecated in native pandas. We raise explicit error for this. + raise ValueError( + "Can not merge objects with different column levels." + + f" ({self.columns.nlevels} levels on the left," + + f" {right_column_nlevels} on the right)" + ) + + # Merge empty native pandas dataframes for error checking. Otherwise, it will + # require a lot of logic to be written. This takes care of raising errors for + # following scenarios: + # 1. Only 'left_index' is set to True. + # 2. Only 'right_index is set to True. + # 3. Only 'left_on' is provided. + # 4. Only 'right_on' is provided. + # 5. 'on' and 'left_on' both are provided + # 6. 'on' and 'right_on' both are provided + # 7. 'on' and 'left_index' both are provided + # 8. 'on' and 'right_index' both are provided + # 9. 'left_on' and 'left_index' both are provided + # 10. 'right_on' and 'right_index' both are provided + # 11. Length mismatch between 'left_on' and 'right_on' + # 12. 'left_index' is not a bool + # 13. 'right_index' is not a bool + # 14. 'on' is not None and how='cross' + # 15. 'left_on' is not None and how='cross' + # 16. 'right_on' is not None and how='cross' + # 17. 'left_index' is True and how='cross' + # 18. 'right_index' is True and how='cross' + # 19. Unknown label in 'on', 'left_on' or 'right_on' + # 20. Provided 'suffixes' is not sufficient to resolve conflicts. + # 21. Merging on column with duplicate labels. + # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'} + # 23. conflict with existing labels for array-like join key + # 24. 'indicator' argument is not bool or str + # 25. indicator column label conflicts with existing data labels + create_empty_native_pandas_frame(self).merge( + create_empty_native_pandas_frame(right), + on=on, + how=how, + left_on=replace_external_data_keys_with_empty_pandas_series(left_on), + right_on=replace_external_data_keys_with_empty_pandas_series(right_on), + left_index=left_index, + right_index=right_index, + suffixes=suffixes, + indicator=indicator, + ) + + return self.__constructor__( + query_compiler=self._query_compiler.merge( + right._query_compiler, + how=how, + on=on, + left_on=replace_external_data_keys_with_query_compiler(self, left_on), + right_on=replace_external_data_keys_with_query_compiler(right, right_on), + left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + copy=copy, + indicator=indicator, + validate=validate, + ) + ) @_inherit_docstrings(native_pd.DataFrame.memory_usage, apilink="pandas.DataFrame") @@ -62,6 +1485,125 @@ def memory_usage(self, index: bool = True, deep: bool = False) -> Any: return native_pd.Series([0] * len(columns), index=columns) +# Snowpark pandas handles `inplace` differently. +@register_dataframe_accessor("replace") +def replace( + self, + to_replace=None, + value=no_default, + inplace: bool = False, + limit=None, + regex: bool = False, + method: str | NoDefault = no_default, +): + """ + Replace values given in `to_replace` with `value`. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + new_query_compiler = self._query_compiler.replace( + to_replace=to_replace, + value=value, + limit=limit, + regex=regex, + method=method, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + +# Snowpark pandas interacts with the inplace flag differently. +@register_dataframe_accessor("rename") +def rename( + self, + mapper: Renamer | None = None, + *, + index: Renamer | None = None, + columns: Renamer | None = None, + axis: Axis | None = None, + copy: bool | None = None, + inplace: bool = False, + level: Level | None = None, + errors: IgnoreRaise = "ignore", +) -> DataFrame | None: + """ + Alter axes labels. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + if mapper is None and index is None and columns is None: + raise TypeError("must pass an index to rename") + + if index is not None or columns is not None: + if axis is not None: + raise TypeError( + "Cannot specify both 'axis' and any of 'index' or 'columns'" + ) + elif mapper is not None: + raise TypeError( + "Cannot specify both 'mapper' and any of 'index' or 'columns'" + ) + else: + # use the mapper argument + if axis and self._get_axis_number(axis) == 1: + columns = mapper + else: + index = mapper + + if copy is not None: + WarningMessage.ignored_argument( + operation="dataframe.rename", + argument="copy", + message="copy parameter has been ignored with Snowflake execution engine", + ) + + if isinstance(index, dict): + index = Series(index) + + new_qc = self._query_compiler.rename( + index_renamer=index, columns_renamer=columns, level=level, errors=errors + ) + return self._create_or_update_from_compiler( + new_query_compiler=new_qc, inplace=inplace + ) + + +# Upstream modin converts aggfunc to a cython function if it's a string. +@register_dataframe_accessor("pivot_table") +def pivot_table( + self, + values=None, + index=None, + columns=None, + aggfunc="mean", + fill_value=None, + margins=False, + dropna=True, + margins_name="All", + observed=False, + sort=True, +): + """ + Create a spreadsheet-style pivot table as a ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + result = self.__constructor__( + query_compiler=self._query_compiler.pivot_table( + index=index, + values=values, + columns=columns, + aggfunc=aggfunc, + fill_value=fill_value, + margins=margins, + dropna=dropna, + margins_name=margins_name, + observed=observed, + sort=sort, + ) + ) + return result + + +# Snowpark pandas produces a different warning for materialization. @register_dataframe_accessor("plot") @property def plot( @@ -108,11 +1650,227 @@ def plot( return self._to_pandas().plot +# Upstream Modin defaults when other is a Series. +@register_dataframe_accessor("pow") +def pow( + self, other, axis="columns", level=None, fill_value=None +): # noqa: PR01, RT01, D200 + """ + Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "pow", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + +@register_dataframe_accessor("rpow") +def rpow( + self, other, axis="columns", level=None, fill_value=None +): # noqa: PR01, RT01, D200 + """ + Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "rpow", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + +# Snowpark pandas does extra argument validation, and uses iloc instead of drop at the end. +@register_dataframe_accessor("select_dtypes") +def select_dtypes( + self, + include: ListLike | str | type | None = None, + exclude: ListLike | str | type | None = None, +) -> DataFrame: + """ + Return a subset of the ``DataFrame``'s columns based on the column dtypes. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # This line defers argument validation to pandas, which will raise errors on our behalf in cases + # like if `include` and `exclude` are None, the same type is specified in both lists, or a string + # dtype (as opposed to object) is specified. + native_pd.DataFrame().select_dtypes(include, exclude) + + if include and not is_list_like(include): + include = [include] + elif include is None: + include = [] + if exclude and not is_list_like(exclude): + exclude = [exclude] + elif exclude is None: + exclude = [] + + sel = tuple(map(set, (include, exclude))) + + # The width of the np.int_/float_ alias differs between Windows and other platforms, so + # we need to include a workaround. + # https://github.com/numpy/numpy/issues/9464 + # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036 + def check_sized_number_infer_dtypes(dtype): + if (isinstance(dtype, str) and dtype == "int") or (dtype is int): + return [np.int32, np.int64] + elif dtype == "float" or dtype is float: + return [np.float64, np.float32] + else: + return [infer_dtype_from_object(dtype)] + + include, exclude = map( + lambda x: set( + itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x)) + ), + sel, + ) + # We need to index on column position rather than label in case of duplicates + include_these = native_pd.Series(not bool(include), index=range(len(self.columns))) + exclude_these = native_pd.Series(not bool(exclude), index=range(len(self.columns))) + + def is_dtype_instance_mapper(dtype): + return functools.partial(issubclass, dtype.type) + + for i, dtype in enumerate(self.dtypes): + if include: + include_these[i] = any(map(is_dtype_instance_mapper(dtype), include)) + if exclude: + exclude_these[i] = not any(map(is_dtype_instance_mapper(dtype), exclude)) + + dtype_indexer = include_these & exclude_these + indicate = [i for i, should_keep in dtype_indexer.items() if should_keep] + # We need to use iloc instead of drop in case of duplicate column names + return self.iloc[:, indicate] + + +# Snowpark pandas does extra validation on the `axis` argument. +@register_dataframe_accessor("set_axis") +def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, # ignored +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not is_scalar(axis): + raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") + return super(DataFrame, self).set_axis( + labels=labels, + # 'columns', 'rows, 'index, 0, and 1 are the only valid axis values for df. + axis=native_pd.DataFrame._get_axis_name(axis), + copy=copy, + ) + + +# Snowpark pandas needs extra logic for the lazy index class. +@register_dataframe_accessor("set_index") +def set_index( + self, + keys: IndexLabel + | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable], + drop: bool = True, + append: bool = False, + inplace: bool = False, + verify_integrity: bool = False, +) -> None | DataFrame: + """ + Set the ``DataFrame`` index using existing columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + if not isinstance(keys, list): + keys = [keys] + + # make sure key is either hashable, index, or series + label_or_series = [] + + missing = [] + columns = self.columns.tolist() + for key in keys: + raise_if_native_pandas_objects(key) + if isinstance(key, pd.Series): + label_or_series.append(key._query_compiler) + elif isinstance(key, (np.ndarray, list, Iterator)): + label_or_series.append(pd.Series(key)._query_compiler) + elif isinstance(key, (pd.Index, native_pd.MultiIndex)): + label_or_series += [s._query_compiler for s in self._to_series_list(key)] + else: + if not is_hashable(key): + raise TypeError( + f'The parameter "keys" may be a column key, one-dimensional array, or a list ' + f"containing only valid column keys and one-dimensional arrays. Received column " + f"of type {type(key)}" + ) + label_or_series.append(key) + found = key in columns + if columns.count(key) > 1: + raise ValueError(f"The column label '{key}' is not unique") + elif not found: + missing.append(key) + + if missing: + raise KeyError(f"None of {missing} are in the columns") + + new_query_compiler = self._query_compiler.set_index( + label_or_series, drop=drop, append=append + ) + + # TODO: SNOW-782633 improve this code once duplicate is supported + # this needs to pull all index which is inefficient + if verify_integrity and not new_query_compiler.index.is_unique: + duplicates = new_query_compiler.index[ + new_query_compiler.index.to_pandas().duplicated() + ].unique() + raise ValueError(f"Index has duplicate keys: {duplicates}") + + return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace) + + +# Upstream Modin uses `len(self.index)` instead of `len(self)`, which gives an extra query. +@register_dataframe_accessor("shape") +@property +def shape(self) -> tuple[int, int]: + """ + Return a tuple representing the dimensionality of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return len(self), len(self.columns) + + +# Snowpark pands has rewrites to minimize queries from length checks. +@register_dataframe_accessor("squeeze") +def squeeze(self, axis: Axis | None = None): + """ + Squeeze 1 dimensional axis objects into scalars. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) if axis is not None else None + len_columns = self._query_compiler.get_axis_len(1) + if axis == 1 and len_columns == 1: + return Series(query_compiler=self._query_compiler) + if axis in [0, None]: + # get_axis_len(0) results in a sql query to count number of rows in current + # dataframe. We should only compute len_index if axis is 0 or None. + len_index = len(self) + if axis is None and (len_columns == 1 or len_index == 1): + return Series(query_compiler=self._query_compiler).squeeze() + if axis == 0 and len_index == 1: + return Series(query_compiler=self.T._query_compiler) + return self.copy() + + # Upstream modin defines sum differently for series/DF, but we use the same implementation for both. @register_dataframe_accessor("sum") def sum( self, - axis: Union[Axis, None] = None, + axis: Axis | None = None, skipna: bool = True, numeric_only: bool = False, min_count: int = 0, @@ -130,6 +1888,70 @@ def sum( ) +# Snowpark pandas raises a warning where modin defaults to pandas. +@register_dataframe_accessor("stack") +def stack( + self, + level: int | str | list = -1, + dropna: bool | NoDefault = no_default, + sort: bool | NoDefault = no_default, + future_stack: bool = False, # ignored +): + """ + Stack the prescribed level(s) from columns to index. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if future_stack is not False: + WarningMessage.ignored_argument( # pragma: no cover + operation="DataFrame.stack", + argument="future_stack", + message="future_stack parameter has been ignored with Snowflake execution engine", + ) + if dropna is NoDefault: + dropna = True # pragma: no cover + if sort is NoDefault: + sort = True # pragma: no cover + + # This ensures that non-pandas MultiIndex objects are caught. + is_multiindex = len(self.columns.names) > 1 + if not is_multiindex or ( + is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels + ): + return self._reduce_dimension( + query_compiler=self._query_compiler.stack(level, dropna, sort) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.stack(level, dropna, sort) + ) + + +# Upstream modin doesn't pass `copy`, so we can't raise a warning for it. +# No need to override the `T` property since that can't take any extra arguments. +@register_dataframe_accessor("transpose") +def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200 + """ + Transpose index and columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if copy: + WarningMessage.ignored_argument( + operation="transpose", + argument="copy", + message="Transpose ignore copy argument in Snowpark pandas API", + ) + + if args: + WarningMessage.ignored_argument( + operation="transpose", + argument="args", + message="Transpose ignores args in Snowpark pandas API", + ) + + return self.__constructor__(query_compiler=self._query_compiler.transpose()) + + +# Upstream modin implements transform in base.py, but we don't yet support Series.transform. @register_dataframe_accessor("transform") def transform( self, func: PythonFuncType, axis: Axis = 0, *args: Any, **kwargs: Any @@ -151,3 +1973,380 @@ def transform( raise ValueError("Function did not transform") return self.apply(func, axis, False, args=args, **kwargs) + + +# Upstream modin defaults to pandas for some arguments. +@register_dataframe_accessor("unstack") +def unstack( + self, + level: int | str | list = -1, + fill_value: int | str | dict = None, + sort: bool = True, +): + """ + Pivot a level of the (necessarily hierarchical) index labels. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # This ensures that non-pandas MultiIndex objects are caught. + nlevels = self._query_compiler.nlevels() + is_multiindex = nlevels > 1 + + if not is_multiindex or ( + is_multiindex and is_list_like(level) and len(level) == nlevels + ): + return self._reduce_dimension( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=False + ) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=False + ) + ) + + +# Upstream modin does different validation and sorting. +@register_dataframe_accessor("value_counts") +def value_counts( + self, + subset: Sequence[Hashable] | None = None, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + dropna: bool = True, +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return Series( + query_compiler=self._query_compiler.value_counts( + subset=subset, + normalize=normalize, + sort=sort, + ascending=ascending, + dropna=dropna, + ), + name="proportion" if normalize else "count", + ) + + +@register_dataframe_accessor("where") +def where( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + """ + Replace values where the condition is False. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(other, Series) and axis is None: + raise ValueError( + "df.where requires an axis parameter (0 or 1) when given a Series" + ) + + return super(DataFrame, self).where( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + +# Snowpark pandas has a custom iterator. +@register_dataframe_accessor("iterrows") +def iterrows(self) -> Iterator[tuple[Hashable, Series]]: + """ + Iterate over ``DataFrame`` rows as (index, ``Series``) pairs. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + def iterrow_builder(s): + """Return tuple of the given `s` parameter name and the parameter themselves.""" + return s.name, s + + # Raise warning message since iterrows is very inefficient. + WarningMessage.single_warning( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows") + ) + + partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder) + yield from partition_iterator + + +# Snowpark pandas has a custom iterator. +@register_dataframe_accessor("itertuples") +def itertuples( + self, index: bool = True, name: str | None = "Pandas" +) -> Iterable[tuple[Any, ...]]: + """ + Iterate over ``DataFrame`` rows as ``namedtuple``-s. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + + def itertuples_builder(s): + """Return the next namedtuple.""" + # s is the Series of values in the current row. + fields = [] # column names + data = [] # values under each column + + if index: + data.append(s.name) + fields.append("Index") + + # Fill column names and values. + fields.extend(list(self.columns)) + data.extend(s) + + if name is not None: + # Creating the namedtuple. + itertuple = collections.namedtuple(name, fields, rename=True) + return itertuple._make(data) + + # When the name is None, return a regular tuple. + return tuple(data) + + # Raise warning message since itertuples is very inefficient. + WarningMessage.single_warning( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples") + ) + return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True) + + +# Snowpark pandas truncates the repr output. +@register_dataframe_accessor("__repr__") +def __repr__(self): + """ + Return a string representation for a particular ``DataFrame``. + + Returns + ------- + str + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + num_rows = native_pd.get_option("display.max_rows") or len(self) + # see _repr_html_ for comment, allow here also all column behavior + num_cols = native_pd.get_option("display.max_columns") or len(self.columns) + + ( + row_count, + col_count, + repr_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x") + result = repr(repr_df) + + # if truncated, add shape information + if is_repr_truncated(row_count, col_count, num_rows, num_cols): + # The split here is so that we don't repr pandas row lengths. + return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format( + row_count, col_count + ) + else: + return result + + +# Snowpark pandas uses a different default `num_rows` value. +@register_dataframe_accessor("_repr_html_") +def _repr_html_(self): # pragma: no cover + """ + Return a html representation for a particular ``DataFrame``. + + Returns + ------- + str + + Notes + ----- + Supports pandas `display.max_rows` and `display.max_columns` options. + """ + num_rows = native_pd.get_option("display.max_rows") or 60 + # Modin uses here 20 as default, but this does not coincide well with pandas option. Therefore allow + # here value=0 which means display all columns. + num_cols = native_pd.get_option("display.max_columns") + + ( + row_count, + col_count, + repr_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols) + result = repr_df._repr_html_() + + if is_repr_truncated(row_count, col_count, num_rows, num_cols): + # We split so that we insert our correct dataframe dimensions. + return ( + result.split("

")[0] + + f"

{row_count} rows × {col_count} columns

\n" + ) + else: + return result + + +# Upstream modin just uses `to_datetime` rather than `dataframe_to_datetime` on the query compiler. +@register_dataframe_accessor("_to_datetime") +def _to_datetime(self, **kwargs): + """ + Convert `self` to datetime. + + Parameters + ---------- + **kwargs : dict + Optional arguments to use during query compiler's + `to_datetime` invocation. + + Returns + ------- + Series of datetime64 dtype + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._reduce_dimension( + query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) + ) + + +# Snowpark pandas has the extra `statement_params` argument. +@register_dataframe_accessor("_to_pandas") +def _to_pandas( + self, + *, + statement_params: dict[str, str] | None = None, + **kwargs: Any, +) -> native_pd.DataFrame: + """ + Convert Snowpark pandas DataFrame to pandas DataFrame + + Args: + statement_params: Dictionary of statement level parameters to be set while executing this action. + + Returns: + pandas DataFrame + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.to_pandas(statement_params=statement_params, **kwargs) + + +# Snowpark pandas does more validation and error checking than upstream Modin, and uses different +# helper methods for dispatch. +@register_dataframe_accessor("__setitem__") +def __setitem__(self, key: Any, value: Any): + """ + Set attribute `value` identified by `key`. + + Args: + key: Key to set + value: Value to set + + Note: + In the case where value is any list like or array, pandas checks the array length against the number of rows + of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw + a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if + the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use + enlargement filling with the last value in the array. + + Returns: + None + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + key = apply_if_callable(key, self) + if isinstance(key, DataFrame) or ( + isinstance(key, np.ndarray) and len(key.shape) == 2 + ): + # This case uses mask's codepath to perform the set, but + # we need to duplicate the code here since we are passing + # an additional kwarg `cond_fillna_with_true` to the QC here. + # We need this additional kwarg, since if df.shape + # and key.shape do not align (i.e. df has more rows), + # mask's codepath would mask the additional rows in df + # while for setitem, we need to keep the original values. + if not isinstance(key, DataFrame): + if key.dtype != bool: + raise TypeError( + "Must pass DataFrame or 2-d ndarray with boolean values only" + ) + key = DataFrame(key) + key._query_compiler._shape_hint = "array" + + if value is not None: + value = apply_if_callable(value, self) + + if isinstance(value, np.ndarray): + value = DataFrame(value) + value._query_compiler._shape_hint = "array" + elif isinstance(value, pd.Series): + # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this + # error instead, since it is more descriptive. + raise ValueError( + "setitem with a 2D key does not support Series values." + ) + + if isinstance(value, BasePandasDataset): + value = value._query_compiler + + query_compiler = self._query_compiler.mask( + cond=key._query_compiler, + other=value, + axis=None, + level=None, + cond_fillna_with_true=True, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace=True) + + # Error Checking: + if (isinstance(key, pd.Series) or is_list_like(key)) and (isinstance(value, range)): + raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) + elif isinstance(value, slice): + # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. + raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) + + # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column + # key. + index, columns = slice(None), key + index_is_bool_indexer = False + if isinstance(key, slice): + if is_integer(key.start) and is_integer(key.stop): + # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as + # df.iloc[1:2, :] = val + self.iloc[key] = value + return + index, columns = key, slice(None) + elif isinstance(key, pd.Series): + if is_bool_dtype(key.dtype): + index, columns = key, slice(None) + index_is_bool_indexer = True + elif is_bool_indexer(key): + index, columns = pd.Series(key), slice(None) + index_is_bool_indexer = True + + # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case + # we have to explicitly set matching_item_columns_by_label to False for setitem. + index = index._query_compiler if isinstance(index, BasePandasDataset) else index + columns = ( + columns._query_compiler if isinstance(columns, BasePandasDataset) else columns + ) + from snowflake.snowpark.modin.pandas.indexing import is_2d_array + + matching_item_rows_by_label = not is_2d_array(value) + if is_2d_array(value): + value = DataFrame(value) + item = value._query_compiler if isinstance(value, BasePandasDataset) else value + new_qc = self._query_compiler.set_2d_labels( + index, + columns, + item, + # setitem always matches item by position + matching_item_columns_by_label=False, + matching_item_rows_by_label=matching_item_rows_by_label, + index_is_bool_indexer=index_is_bool_indexer, + # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling + # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the + # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have + # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns + # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", + # "X", "X". + deduplicate_columns=True, + ) + return self._update_inplace(new_query_compiler=new_qc) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 12710224de7..643f6f5038e 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -30,6 +30,7 @@ import modin import numpy as np import pandas as native_pd +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas import get_option from pandas._libs import lib @@ -49,7 +50,6 @@ ) from pandas.core.dtypes.inference import is_hashable -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin._internal.timestamp_utils import DateTimeOrigin diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py index c5f9e4f6cee..43f9603cfb4 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py @@ -9,10 +9,10 @@ import inspect from typing import Any, Iterable, Literal, Optional, Union +from modin.pandas import DataFrame, Series from pandas._typing import IndexLabel from snowflake.snowpark import DataFrame as SnowparkDataFrame -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_standalone_function_decorator, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py index dea98bbb0d3..6d6fb4cd0bd 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py @@ -15,6 +15,7 @@ ) import pandas as native_pd +from modin.pandas import DataFrame from pandas._libs.lib import NoDefault, no_default from pandas._typing import ( CSVEngine, @@ -26,7 +27,6 @@ ) import snowflake.snowpark.modin.pandas as pd -from snowflake.snowpark.modin.pandas import DataFrame from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_standalone_function_decorator, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py index f7bba4c743a..5b245bfdab4 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py @@ -181,7 +181,6 @@ def to_pandas( See Also: - :func:`to_pandas ` - - :func:`DataFrame.to_pandas ` Returns: pandas Series diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 5011defa685..b104c223e26 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -9,22 +9,13 @@ from __future__ import annotations -from typing import ( - IO, - TYPE_CHECKING, - Any, - Callable, - Hashable, - Literal, - Mapping, - Sequence, -) +from typing import IO, Any, Callable, Hashable, Literal, Mapping, Sequence import modin.pandas as pd import numpy as np import numpy.typing as npt import pandas as native_pd -from modin.pandas import Series +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas._libs.lib import NoDefault, is_integer, no_default from pandas._typing import ( @@ -73,9 +64,6 @@ validate_int_kwarg, ) -if TYPE_CHECKING: - from modin.pandas import DataFrame - def register_series_not_implemented(): def decorator(base_method: Any): @@ -209,21 +197,6 @@ def hist( pass # pragma: no cover -@register_series_not_implemented() -def interpolate( - self, - method="linear", - axis=0, - limit=None, - inplace=False, - limit_direction: str | None = None, - limit_area=None, - downcast=None, - **kwargs, -): # noqa: PR01, RT01, D200 - pass # pragma: no cover - - @register_series_not_implemented() def item(self): # noqa: RT01, D200 pass # pragma: no cover @@ -1451,9 +1424,7 @@ def set_axis( ) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. +# Snowpark pandas does different validation. @register_series_accessor("rename") def rename( self, @@ -1503,9 +1474,36 @@ def rename( return self_cp -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. +# Modin defaults to pandas for some arguments for unstack +@register_series_accessor("unstack") +def unstack( + self, + level: int | str | list = -1, + fill_value: int | str | dict = None, + sort: bool = True, +): + """ + Unstack, also known as pivot, Series with MultiIndex to produce DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + from modin.pandas.dataframe import DataFrame + + # We can't unstack a Series object, if we don't have a MultiIndex. + if self._query_compiler.has_multiindex: + result = DataFrame( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=True + ) + ) + else: + raise ValueError( # pragma: no cover + f"index must be a MultiIndex to unstack, {type(self.index)} was passed" + ) + + return result + + +# Snowpark pandas does an extra check on `len(ascending)`. @register_series_accessor("sort_values") def sort_values( self, @@ -1521,7 +1519,7 @@ def sort_values( Sort by the values. """ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame + from modin.pandas.dataframe import DataFrame if is_list_like(ascending) and len(ascending) != 1: raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series") @@ -1550,38 +1548,6 @@ def sort_values( return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. -# Modin also defaults to pandas for some arguments for unstack -@register_series_accessor("unstack") -def unstack( - self, - level: int | str | list = -1, - fill_value: int | str | dict = None, - sort: bool = True, -): - """ - Unstack, also known as pivot, Series with MultiIndex to produce DataFrame. - """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - # We can't unstack a Series object, if we don't have a MultiIndex. - if self._query_compiler.has_multiindex: - result = DataFrame( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=True - ) - ) - else: - raise ValueError( # pragma: no cover - f"index must be a MultiIndex to unstack, {type(self.index)} was passed" - ) - - return result - - # Upstream Modin defaults at the frontend layer. @register_series_accessor("where") def where( @@ -1727,63 +1693,6 @@ def to_dict(self, into: type[dict] = dict) -> dict: return self._to_pandas().to_dict(into=into) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored -# version of DataFrame, we must keep this override. -@register_series_accessor("_create_or_update_from_compiler") -def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): - """ - Return or update a Series with given `new_query_compiler`. - - Parameters - ---------- - new_query_compiler : PandasQueryCompiler - QueryCompiler to use to manage the data. - inplace : bool, default: False - Whether or not to perform update or creation inplace. - - Returns - ------- - Series, DataFrame or None - None if update was done, Series or DataFrame otherwise. - """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - assert ( - isinstance(new_query_compiler, type(self._query_compiler)) - or type(new_query_compiler) in self._query_compiler.__class__.__bases__ - ), f"Invalid Query Compiler object: {type(new_query_compiler)}" - if not inplace and new_query_compiler.is_series_like(): - return self.__constructor__(query_compiler=new_query_compiler) - elif not inplace: - # This can happen with things like `reset_index` where we can add columns. - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - return DataFrame(query_compiler=new_query_compiler) - else: - self._update_inplace(new_query_compiler=new_query_compiler) - - -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored -# version of DataFrame, we must keep this override. -@register_series_accessor("to_frame") -def to_frame(self, name: Hashable = no_default) -> DataFrame: # noqa: PR01, RT01, D200 - """ - Convert Series to {label -> value} dict or dict-like object. - """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - if name is None: - name = no_default - - self_cp = self.copy() - if name is not no_default: - self_cp.name = name - - return DataFrame(self_cp) - - @register_series_accessor("to_numpy") def to_numpy( self, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 9cb4ffa7327..1cd5e31c63f 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -28,11 +28,11 @@ import numpy as np import pandas as native_pd +from modin.pandas import DataFrame, Series from pandas._libs import lib from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable from pandas.core.dtypes.common import is_timedelta64_dtype -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, ) diff --git a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py index 785a492ca89..f3102115a32 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py +++ b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py @@ -42,3 +42,17 @@ SERIES_SETITEM_INCOMPATIBLE_INDEXER_WITH_SCALAR_ERROR_MESSAGE = ( "Scalar key incompatible with {} value" ) + +DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = ( + "Currently do not support Series or list-like keys with range-like values" +) + +DF_SETITEM_SLICE_AS_SCALAR_VALUE = ( + "Currently do not support assigning a slice value as if it's a scalar value" +) + +DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = ( + "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark " + "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which " + "can work on the entire DataFrame in one shot." +) diff --git a/src/snowflake/snowpark/modin/utils.py b/src/snowflake/snowpark/modin/utils.py index b1027f00e33..b3446ca0362 100644 --- a/src/snowflake/snowpark/modin/utils.py +++ b/src/snowflake/snowpark/modin/utils.py @@ -1171,7 +1171,7 @@ def validate_int_kwarg(value: int, arg_name: str, float_allowed: bool = False) - def doc_replace_dataframe_with_link(_obj: Any, doc: str) -> str: """ Helper function to be passed as the `modify_doc` parameter to `_inherit_docstrings`. This replaces - all unqualified instances of "DataFrame" with ":class:`~snowflake.snowpark.pandas.DataFrame`" to + all unqualified instances of "DataFrame" with ":class:`~modin.pandas.DataFrame`" to prevent it from linking automatically to snowflake.snowpark.DataFrame: see SNOW-1233342. To prevent it from overzealously replacing examples in doctests or already-qualified paths, it diff --git a/tests/integ/modin/frame/test_info.py b/tests/integ/modin/frame/test_info.py index 2a096e76fdc..fbbf8dfe041 100644 --- a/tests/integ/modin/frame/test_info.py +++ b/tests/integ/modin/frame/test_info.py @@ -13,9 +13,7 @@ def _assert_info_lines_equal(modin_info: list[str], pandas_info: list[str]): # class is different - assert ( - modin_info[0] == "" - ) + assert modin_info[0] == "" assert pandas_info[0] == "" # index is different diff --git a/tests/integ/modin/test_classes.py b/tests/integ/modin/test_classes.py index c92bb85c531..6e6c2eda8eb 100644 --- a/tests/integ/modin/test_classes.py +++ b/tests/integ/modin/test_classes.py @@ -34,14 +34,14 @@ def test_class_names_constructors(): expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) s = pd.Series(index=[1, 2, 3], data=[3, 2, 1]) expect_type_check( s, pd.Series, - "snowflake.snowpark.modin.pandas.series.Series", + "modin.pandas.series.Series", ) @@ -63,7 +63,7 @@ def test_op(): expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) @@ -77,7 +77,7 @@ def test_native_conversion(): expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) # Snowpark pandas -> native pandas diff --git a/tests/integ/modin/test_telemetry.py b/tests/integ/modin/test_telemetry.py index ce9e1caf328..a36298af251 100644 --- a/tests/integ/modin/test_telemetry.py +++ b/tests/integ/modin/test_telemetry.py @@ -110,7 +110,7 @@ def test_snowpark_pandas_telemetry_method_decorator(test_table_name): df1_expected_api_calls = [ {"name": "TestClass.test_func"}, - {"name": "DataFrame.DataFrame.dropna", "argument": ["inplace"]}, + {"name": "DataFrame.dropna", "argument": ["inplace"]}, ] assert df1._query_compiler.snowpark_pandas_api_calls == df1_expected_api_calls @@ -121,7 +121,7 @@ def test_snowpark_pandas_telemetry_method_decorator(test_table_name): assert df1._query_compiler.snowpark_pandas_api_calls == df1_expected_api_calls df2_expected_api_calls = df1_expected_api_calls + [ { - "name": "DataFrame.DataFrame.dropna", + "name": "DataFrame.dropna", }, ] assert df2._query_compiler.snowpark_pandas_api_calls == df2_expected_api_calls @@ -336,10 +336,7 @@ def test_telemetry_with_update_inplace(): df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) df.insert(1, "newcol", [99, 99, 90]) assert len(df._query_compiler.snowpark_pandas_api_calls) == 1 - assert ( - df._query_compiler.snowpark_pandas_api_calls[0]["name"] - == "DataFrame.DataFrame.insert" - ) + assert df._query_compiler.snowpark_pandas_api_calls[0]["name"] == "DataFrame.insert" @sql_count_checker(query_count=1) @@ -403,8 +400,8 @@ def test_telemetry_getitem_setitem(): df["a"] = 0 df["b"] = 0 assert df._query_compiler.snowpark_pandas_api_calls == [ - {"name": "DataFrame.DataFrame.__setitem__"}, - {"name": "DataFrame.DataFrame.__setitem__"}, + {"name": "DataFrame.__setitem__"}, + {"name": "DataFrame.__setitem__"}, ] # Clear connector telemetry client buffer to avoid flush triggered by the next API call, ensuring log extraction. s._query_compiler._modin_frame.ordered_dataframe.session._conn._telemetry_client.telemetry.send_batch() @@ -422,13 +419,17 @@ def test_telemetry_getitem_setitem(): @pytest.mark.parametrize( - "name, method, expected_query_count", + "name, expected_func_name, method, expected_query_count", [ - ["__repr__", lambda df: df.__repr__(), 1], - ["__iter__", lambda df: df.__iter__(), 0], + # __repr__ is an extension method, so the class name is shown only once. + ["__repr__", "DataFrame.__repr__", lambda df: df.__repr__(), 1], + # __iter__ was defined on the DataFrame class, so it is shown twice. + ["__iter__", "DataFrame.DataFrame.__iter__", lambda df: df.__iter__(), 0], ], ) -def test_telemetry_private_method(name, method, expected_query_count): +def test_telemetry_private_method( + name, expected_func_name, method, expected_query_count +): df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) # Clear connector telemetry client buffer to avoid flush triggered by the next API call, ensuring log extraction. df._query_compiler._modin_frame.ordered_dataframe.session._conn._telemetry_client.telemetry.send_batch() @@ -439,10 +440,10 @@ def test_telemetry_private_method(name, method, expected_query_count): # the telemetry log from the connector to validate data = _extract_snowpark_pandas_telemetry_log_data( - expected_func_name=f"DataFrame.DataFrame.{name}", + expected_func_name=expected_func_name, session=df._query_compiler._modin_frame.ordered_dataframe.session, ) - assert data["api_calls"] == [{"name": f"DataFrame.DataFrame.{name}"}] + assert data["api_calls"] == [{"name": expected_func_name}] @sql_count_checker(query_count=0) diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py index 7c5e3a40bb0..d94c80b8d67 100644 --- a/tests/unit/modin/modin/test_envvars.py +++ b/tests/unit/modin/modin/test_envvars.py @@ -166,6 +166,7 @@ def test_overrides(self): # Test for pandas doc when function is not defined on module. assert pandas.read_table.__doc__ in pd.read_table.__doc__ + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_not_redefining_classes_modin_issue_7138(self): original_dataframe_class = pd.DataFrame _init_doc_module() From 0ee303348abf772fb161d1edf4f75728089da759 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Mon, 16 Sep 2024 11:01:57 -0700 Subject: [PATCH 10/22] [SNOW-1632895] Add derive_dependent_columns_with_duplication capability (#2272) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1632895 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. in order to handle nested select with column dependency (not handled by sql simplifier) complexity calculation, added a utility support for deriving all column depenencies with duplication, for example, col('a') + col('b') + 3*('a'), should return dependency ['a', 'b', 'a'] this provides both information about the columns it dependents on and also the number of times/ --- .../_internal/analyzer/binary_expression.py | 6 +- .../snowpark/_internal/analyzer/expression.py | 96 +++++++++++++- .../_internal/analyzer/grouping_set.py | 8 ++ .../_internal/analyzer/sort_expression.py | 6 +- .../_internal/analyzer/unary_expression.py | 6 +- .../_internal/analyzer/window_expression.py | 17 +++ .../unit/test_expression_dependent_columns.py | 125 ++++++++++++++++++ 7 files changed, 259 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 3ed969caada..22591f55e47 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional +from typing import AbstractSet, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -29,6 +30,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.left, self.right) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index a2d21db4eb2..a7cb5fd97a9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -35,6 +35,13 @@ def derive_dependent_columns( *expressions: "Optional[Expression]", ) -> Optional[AbstractSet[str]]: + """ + Given set of expressions, derive the set of columns that the expressions dependents on. + + Note, the returned dependent columns is a set without duplication. For example, given expression + concat(col1, upper(co1), upper(col2)), the result will be {col1, col2} even if col1 has + occurred in the given expression twice. + """ result = set() for exp in expressions: if exp is not None: @@ -48,6 +55,23 @@ def derive_dependent_columns( return result +def derive_dependent_columns_with_duplication( + *expressions: "Optional[Expression]", +) -> List[str]: + """ + Given set of expressions, derive the list of columns that the expression dependents on. + + Note, the returned columns will have duplication if the column occurred more than once in + the given expression. For example, concat(col1, upper(co1), upper(col2)) will have result + [col1, col1, col2], where col1 occurred twice in the result. + """ + result = [] + for exp in expressions: + if exp is not None: + result.extend(exp.dependent_column_names_with_duplication()) + return result + + class Expression: """Consider removing attributes, and adding properties and methods. A subclass of Expression may have no child, one child, or multiple children. @@ -68,6 +92,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. return COLUMN_DEPENDENCY_EMPTY + def dependent_column_names_with_duplication(self) -> List[str]: + return [] + @property def pretty_name(self) -> str: """Returns a user-facing string representation of this expression's name. @@ -143,6 +170,9 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR + def dependent_column_names_with_duplication(self) -> List[str]: + return list(COLUMN_DEPENDENCY_DOLLAR) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.plan.cumulative_node_complexity @@ -156,6 +186,9 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.expressions) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( @@ -172,6 +205,9 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.columns, *self.values) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN @@ -212,6 +248,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} + def dependent_column_names_with_duplication(self) -> List[str]: + return [self.name] + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -235,6 +274,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: else COLUMN_DEPENDENCY_ALL ) + def dependent_column_names_with_duplication(self) -> List[str]: + return ( + derive_dependent_columns_with_duplication(*self.expressions) + if self.expressions + else [] # we currently do not handle * dependency + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1} @@ -278,6 +324,14 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names + def dependent_column_names_with_duplication(self) -> List[str]: + return ( + [] + if (self._dependent_column_names == COLUMN_DEPENDENCY_ALL) + or (self._dependent_column_names is None) + else list(self._dependent_column_names) + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -371,6 +425,9 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern @@ -400,6 +457,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern @@ -423,6 +483,9 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec @@ -444,6 +507,9 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -466,6 +532,9 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -510,6 +579,9 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -525,6 +597,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols) + @property def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) @@ -549,13 +624,21 @@ def __init__( self.branches = branches self.else_value = else_value - def dependent_column_names(self) -> Optional[AbstractSet[str]]: + @property + def _child_expressions(self) -> List[Expression]: exps = [] for exp_tuple in self.branches: exps.extend(exp_tuple) if self.else_value is not None: exps.append(self.else_value) - return derive_dependent_columns(*exps) + + return exps + + def dependent_column_names(self) -> Optional[AbstractSet[str]]: + return derive_dependent_columns(*self._child_expressions) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self._child_expressions) @property def plan_node_category(self) -> PlanNodeCategory: @@ -602,6 +685,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -617,6 +703,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.col) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -636,6 +725,9 @@ def __init__(self, exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.exprs) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 84cd63fd87d..012940471d0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.group_by_exprs) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) + def dependent_column_names_with_duplication(self) -> List[str]: + flattened_args = [exp for sublist in self.args for exp in sublist] + return derive_dependent_columns_with_duplication(*flattened_args) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 1d06f7290a0..82451245e4c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional, Type +from typing import AbstractSet, List, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) @@ -55,3 +56,6 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e5886e11069..1ae08e8fde2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,12 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Dict, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -36,6 +37,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 69db3f265ce..4381c4a2e22 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -71,6 +72,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.lower, self.upper) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + *self.partition_spec, *self.order_spec, self.frame_spec + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # partition_spec order_by_spec frame_spec @@ -138,6 +147,11 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + self.window_function, self.window_spec + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW @@ -171,6 +185,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.default) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # for func_name diff --git a/tests/unit/test_expression_dependent_columns.py b/tests/unit/test_expression_dependent_columns.py index c31e5cc6290..c9b8a1ce38d 100644 --- a/tests/unit/test_expression_dependent_columns.py +++ b/tests/unit/test_expression_dependent_columns.py @@ -87,30 +87,37 @@ def test_expression(): a = Expression() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] b = Expression(child=UnresolvedAttribute("a")) assert b.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert b.dependent_column_names_with_duplication() == [] # root class Expression always returns empty dependency def test_literal(): a = Literal(5) assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] def test_attribute(): a = Attribute("A", IntegerType()) assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] def test_unresolved_attribute(): a = UnresolvedAttribute("A") assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] b = UnresolvedAttribute("a > 1", is_sql_text=True) assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] c = UnresolvedAttribute("$1 > 1", is_sql_text=True) assert c.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert c.dependent_column_names_with_duplication() == ["$"] def test_case_when(): @@ -118,46 +125,85 @@ def test_case_when(): b = Column("b") z = when(a > b, col("c")).when(a < b, col("d")).else_(col("e")) assert z._expression.dependent_column_names() == {'"A"', '"B"', '"C"', '"D"', '"E"'} + # verify column '"A"', '"B"' occurred twice in the dependency columns + assert z._expression.dependent_column_names_with_duplication() == [ + '"A"', + '"B"', + '"C"', + '"A"', + '"B"', + '"D"', + '"E"', + ] def test_collate(): a = Collate(UnresolvedAttribute("a"), "spec") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_function_expression(): a = FunctionExpression("test_func", [UnresolvedAttribute(x) for x in "abcd"], False) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # expressions with duplicated dependent column + b = FunctionExpression( + "test_func", [UnresolvedAttribute(x) for x in "abcdad"], False + ) + assert b.dependent_column_names() == set("abcd") + assert b.dependent_column_names_with_duplication() == list("abcdad") def test_in_expression(): a = InExpression(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") def test_like(): a = Like(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + # with duplication + b = Like(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] def test_list_agg(): a = ListAgg(UnresolvedAttribute("a"), ",", True) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_multiple_expression(): a = MultipleExpression([UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # with duplication + a = MultipleExpression([UnresolvedAttribute(x) for x in "abcdbea"]) + assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("abcdbea") def test_reg_exp(): a = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + b = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] def test_scalar_subquery(): a = ScalarSubquery(None) assert a.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert a.dependent_column_names_with_duplication() == list(COLUMN_DEPENDENCY_DOLLAR) def test_snowflake_udf(): @@ -165,21 +211,42 @@ def test_snowflake_udf(): "udf_name", [UnresolvedAttribute(x) for x in "abcd"], IntegerType() ) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # with duplication + b = SnowflakeUDF( + "udf_name", [UnresolvedAttribute(x) for x in "abcdfc"], IntegerType() + ) + assert b.dependent_column_names() == set("abcdf") + assert b.dependent_column_names_with_duplication() == list("abcdfc") def test_star(): a = Star([Attribute(x, IntegerType()) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + b = Star([]) + assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] def test_subfield_string(): a = SubfieldString(UnresolvedAttribute("a"), "field") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_within_group(): a = WithinGroup(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") + + b = WithinGroup( + UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcdea"] + ) + assert b.dependent_column_names() == set("abcde") + assert b.dependent_column_names_with_duplication() == list("eabcdea") @pytest.mark.parametrize( @@ -189,16 +256,19 @@ def test_within_group(): def test_unary_expression(expression_class): a = expression_class(child=UnresolvedAttribute("a")) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_alias(): a = Alias(child=Add(UnresolvedAttribute("a"), UnresolvedAttribute("b")), name="c") assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_cast(): a = Cast(UnresolvedAttribute("a"), IntegerType()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] @pytest.mark.parametrize( @@ -234,6 +304,19 @@ def test_binary_expression(expression_class): assert b.dependent_column_names() == {"B"} assert binary_expression.dependent_column_names() == {"A", "B"} + assert a.dependent_column_names_with_duplication() == ["A"] + assert b.dependent_column_names_with_duplication() == ["B"] + assert binary_expression.dependent_column_names_with_duplication() == ["A", "B"] + + # hierarchical expressions with duplication + hierarchical_binary_expression = expression_class(expression_class(a, b), b) + assert hierarchical_binary_expression.dependent_column_names() == {"A", "B"} + assert hierarchical_binary_expression.dependent_column_names_with_duplication() == [ + "A", + "B", + "B", + ] + @pytest.mark.parametrize( "expression_class", @@ -253,6 +336,18 @@ def test_grouping_set(expression_class): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] + + # with duplication + b = expression_class( + [ + UnresolvedAttribute("a"), + UnresolvedAttribute("a"), + UnresolvedAttribute("c"), + ] + ) + assert b.dependent_column_names() == {"a", "c"} + assert b.dependent_column_names_with_duplication() == ["a", "a", "c"] def test_grouping_sets_expression(): @@ -263,11 +358,13 @@ def test_grouping_sets_expression(): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] def test_sort_order(): a = SortOrder(UnresolvedAttribute("a"), Ascending()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_specified_window_frame(): @@ -275,12 +372,21 @@ def test_specified_window_frame(): RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("b") ) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + # with duplication + b = SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("a") + ) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] @pytest.mark.parametrize("expression_class", [RankRelatedFunctionExpression, Lag, Lead]) def test_rank_related_function_expression(expression_class): a = expression_class(UnresolvedAttribute("a"), 1, UnresolvedAttribute("b"), False) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_window_spec_definition(): @@ -295,6 +401,7 @@ def test_window_spec_definition(): ), ) assert a.dependent_column_names() == set("abcdef") + assert a.dependent_column_names_with_duplication() == list("abcdef") def test_window_expression(): @@ -310,6 +417,23 @@ def test_window_expression(): ) a = WindowExpression(UnresolvedAttribute("x"), window_spec_definition) assert a.dependent_column_names() == set("abcdefx") + assert a.dependent_column_names_with_duplication() == list("xabcdef") + + +def test_window_expression_with_duplication_columns(): + window_spec_definition = WindowSpecDefinition( + [UnresolvedAttribute("a"), UnresolvedAttribute("b")], + [ + SortOrder(UnresolvedAttribute("c"), Ascending()), + SortOrder(UnresolvedAttribute("a"), Ascending()), + ], + SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("e"), UnresolvedAttribute("f") + ), + ) + a = WindowExpression(UnresolvedAttribute("e"), window_spec_definition) + assert a.dependent_column_names() == set("abcef") + assert a.dependent_column_names_with_duplication() == list("eabcaef") @pytest.mark.parametrize( @@ -325,3 +449,4 @@ def test_window_expression(): def test_other_window_expressions(expression_class): a = expression_class() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] From e93cd6821d8accf69cbe95c5bb171838de844112 Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Mon, 16 Sep 2024 18:09:18 -0700 Subject: [PATCH 11/22] SNOW-1661142 Fix index name behavior (#2274) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1661142 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. Fixed a bug where updating an index's name updates the parent's index name when it is not supposed to. This is done by verifying that the query_compiler recorded during the index's creation matches that of its parent object when the parent object must be updated. ```py >>> df = pd.DataFrame( ... { ... "A": [0, 1, 2, 3, 4, 4], ... "B": ['a', 'b', 'c', 'd', 'e', 'f'], ... }, ... index = pd.Index([1, 2, 3, 4, 5, 6], name = "test"), ... ) >>> index = df.index >>> df A B test 1 0 a 2 1 b 3 2 c 4 3 d 5 4 e 6 4 f >>> index.name = "test2" >>> >>> df A B test2 1 0 a 2 1 b 3 2 c 4 3 d 5 4 e 6 4 f >>> df.dropna(inplace=True) >>> index.name = "test3" >>> df A B test2 # <--- name should not update 1 0 a 2 1 b 3 2 c 4 3 d 5 4 e 6 4 f ``` For the full discussion, see thread: https://docs.google.com/document/d/1vdllzNgeUHMiffFNpm9SD1HOYUk8lkMVp14HQDoqr7s/edit?disco=AAABVbKjFJ0 --- CHANGELOG.md | 4 ++ .../snowpark/modin/plugin/extensions/index.py | 61 +++++++++++++---- .../index/test_datetime_index_methods.py | 4 +- tests/integ/modin/index/test_index_methods.py | 4 +- tests/integ/modin/index/test_name.py | 66 +++++++++++++++++++ 5 files changed, 122 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bd719dcb8c..7048a3b728a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ - Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. - Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`. +#### Bug Fixes + +- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`. + ## 1.22.1 (2024-09-11) This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 643f6f5038e..b25bb481dc0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -71,6 +71,35 @@ } +class IndexParent: + def __init__(self, parent: DataFrame | Series) -> None: + """ + Initialize the IndexParent object. + + IndexParent is used to keep track of the parent object that the Index is a part of. + It tracks the parent object and the parent object's query compiler at the time of creation. + + Parameters + ---------- + parent : DataFrame or Series + The parent object that the Index is a part of. + """ + assert isinstance(parent, (DataFrame, Series)) + self._parent = parent + self._parent_qc = parent._query_compiler + + def check_and_update_parent_qc_index_names(self, names: list) -> None: + """ + Update the Index and its parent's index names if the query compiler associated with the parent is + different from the original query compiler recorded, i.e., an inplace update has been applied to the parent. + """ + if self._parent._query_compiler is self._parent_qc: + new_query_compiler = self._parent_qc.set_index_names(names) + self._parent._update_inplace(new_query_compiler=new_query_compiler) + # Update the query compiler after naming operation. + self._parent_qc = new_query_compiler + + class Index(metaclass=TelemetryMeta): # Equivalent index type in native pandas @@ -135,7 +164,7 @@ def __new__( index = object.__new__(cls) # Initialize the Index index._query_compiler = query_compiler - # `_parent` keeps track of any Series or DataFrame that this Index is a part of. + # `_parent` keeps track of the parent object that this Index is a part of. index._parent = None return index @@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any: ErrorMessage.not_implemented(f"Index.{key} is not yet implemented") raise err + def _set_parent(self, parent: Series | DataFrame) -> None: + """ + Set the parent object and its query compiler. + + Parameters + ---------- + parent : Series or DataFrame + The parent object that the Index is a part of. + """ + self._parent = IndexParent(parent) + def _binary_ops(self, method: str, other: Any) -> Index: if isinstance(other, Index): other = other.to_series().reset_index(drop=True) @@ -408,12 +448,6 @@ def __constructor__(self): """ return type(self) - def _set_parent(self, parent: Series | DataFrame): - """ - Set the parent object of the current Index to a given Series or DataFrame. - """ - self._parent = parent - @property def values(self) -> ArrayLike: """ @@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None: if not is_hashable(value): raise TypeError(f"{type(self).__name__}.name must be a hashable type") self._query_compiler = self._query_compiler.set_index_names([value]) + # Update the name of the parent's index only if an inplace update is performed on + # the parent object, i.e., the parent's current query compiler matches the originally + # recorded query compiler. if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names([value]) - ) + self._parent.check_and_update_parent_qc_index_names([value]) def _get_names(self) -> list[Hashable]: """ @@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None: if isinstance(values, Index): values = values.to_list() self._query_compiler = self._query_compiler.set_index_names(values) + # Update the name of the parent's index only if the parent's current query compiler + # matches the recorded query compiler. if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names(values) - ) + self._parent.check_and_update_parent_qc_index_names(values) names = property(fset=_set_names, fget=_get_names) diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py index 793485f97d6..98d1a041c3b 100644 --- a/tests/integ/modin/index/test_datetime_index_methods.py +++ b/tests/integ/modin/index/test_datetime_index_methods.py @@ -142,13 +142,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame({"A": [1]}, index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. s = pd.Series([1, 2], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py index 8d0434915ac..6b33eb89889 100644 --- a/tests/integ/modin/index/test_index_methods.py +++ b/tests/integ/modin/index/test_index_methods.py @@ -393,13 +393,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) diff --git a/tests/integ/modin/index/test_name.py b/tests/integ/modin/index/test_name.py index b916110f386..f915598c5f6 100644 --- a/tests/integ/modin/index/test_name.py +++ b/tests/integ/modin/index/test_name.py @@ -351,3 +351,69 @@ def test_index_names_with_lazy_index(): ), inplace=True, ) + + +@sql_count_checker(query_count=1) +def test_index_names_replace_behavior(): + """ + Check that the index name of a DataFrame cannot be updated after the DataFrame has been modified. + """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change the names. + snow_index.name = "test2" + native_index.name = "test2" + + # Compare the names. + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the query compiler the DataFrame is referring to, change the names. + snow_df.dropna(inplace=True) + native_df.dropna(inplace=True) + snow_index.name = "test3" + native_index.name = "test3" + + # Compare the names. Changing the index name should not change the DataFrame's index name. + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test2" + + +@sql_count_checker(query_count=1) +def test_index_names_multiple_renames(): + """ + Check that the index name of a DataFrame can be renamed any number of times. + """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change and compare the names. + snow_index.name = "test2" + native_index.name = "test2" + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the names again and compare. + snow_index.name = "test3" + native_index.name = "test3" + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test3" From 3f605072e9204f9c67068f45668cbaf1960613b3 Mon Sep 17 00:00:00 2001 From: Mahesh Vashishtha Date: Tue, 17 Sep 2024 09:51:07 -0700 Subject: [PATCH 12/22] SNOW-1664064: Suppress SettingWithCopyWarning for Timedelta columns. (#2298) Fixes SNOW-1664064 --------- Signed-off-by: sfc-gh-mvashishtha --- CHANGELOG.md | 1 + .../snowpark/modin/plugin/_internal/utils.py | 12 +++++++++--- tests/integ/modin/types/test_timedelta.py | 9 +++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7048a3b728a..daedfe34659 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ #### Bug Fixes - Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`. +- Suppressed an unhelpful `SettingWithCopyWarning` that sometimes appeared when printing `Timedelta` columns. ## 1.22.1 (2024-09-11) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py index 70025fd8b0a..34a3376fcc1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py @@ -1519,9 +1519,15 @@ def convert_str_to_timedelta(x: str) -> pd.Timedelta: downcast_pandas_df.columns, cached_snowpark_pandas_types ): if snowpark_pandas_type is not None and snowpark_pandas_type == timedelta_t: - downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply( - convert_str_to_timedelta - ) + # By default, pandas warns, "A value is trying to be set on a + # copy of a slice from a DataFrame" here because we are + # assigning a column to downcast_pandas_df, which is a copy of + # a slice of pandas_df. We don't care what happens to pandas_df, + # so the warning isn't useful to us. + with native_pd.option_context("mode.chained_assignment", None): + downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply( + convert_str_to_timedelta + ) # Step 7. postprocessing for return types if index_only: diff --git a/tests/integ/modin/types/test_timedelta.py b/tests/integ/modin/types/test_timedelta.py index 4c72df42bba..d28362374ce 100644 --- a/tests/integ/modin/types/test_timedelta.py +++ b/tests/integ/modin/types/test_timedelta.py @@ -2,10 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # import datetime +import warnings import modin.pandas as pd import pandas as native_pd import pytest +from pandas.errors import SettingWithCopyWarning from tests.integ.modin.sql_counter import sql_count_checker from tests.integ.modin.utils import ( @@ -107,3 +109,10 @@ def test_timedelta_not_supported(): match="SnowflakeQueryCompiler::groupby_groups is not yet implemented for Timedelta Type", ): df.groupby("a").groups() + + +@sql_count_checker(query_count=1) +def test_aggregation_does_not_print_internal_warning_SNOW_1664064(): + with warnings.catch_warnings(): + warnings.simplefilter(category=SettingWithCopyWarning, action="error") + pd.Series(pd.Timedelta(1)).max() From 8414933662b908eac67a40dd0785486f10e4ee2a Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 11:22:27 -0700 Subject: [PATCH 13/22] SNOW-1418523: Remove session variable that can be local (#2279) --- src/snowflake/snowpark/session.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index e66155e01ea..074dfb54b7f 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -601,7 +601,6 @@ def __init__( ) self._custom_package_usage_config: Dict = {} self._conf = self.RuntimeConfig(self, options or {}) - self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) _logger.info("Snowpark Session information: %s", self._session_info) @@ -1710,8 +1709,8 @@ def _upload_unsupported_packages( try: # Setup a temporary directory and target folder where pip install will take place. - self._tmpdir_handler = tempfile.TemporaryDirectory() - tmpdir = self._tmpdir_handler.name + tmpdir_handler = tempfile.TemporaryDirectory() + tmpdir = tmpdir_handler.name target = os.path.join(tmpdir, "unsupported_packages") if not os.path.exists(target): os.makedirs(target) @@ -1796,9 +1795,7 @@ def _upload_unsupported_packages( for requirement in supported_dependencies + new_dependencies ] ) - metadata_local_path = os.path.join( - self._tmpdir_handler.name, metadata_file - ) + metadata_local_path = os.path.join(tmpdir_handler.name, metadata_file) with open(metadata_local_path, "w") as file: for key, value in metadata.items(): file.write(f"{key},{value}\n") @@ -1834,9 +1831,8 @@ def _upload_unsupported_packages( f"-third-party-packages-from-anaconda-in-a-udf." ) finally: - if self._tmpdir_handler: - self._tmpdir_handler.cleanup() - self._tmpdir_handler = None + if tmpdir_handler: + tmpdir_handler.cleanup() return supported_dependencies + new_dependencies From 6554f8209a0026c34c0d0a389a07d5817a605d38 Mon Sep 17 00:00:00 2001 From: Naren Krishna Date: Tue, 17 Sep 2024 12:50:40 -0700 Subject: [PATCH 14/22] SNOW-1662275: Add support for parameters `left_index` and `right_index` for `pd.merge_asof` (#2304) SNOW-1662275 This PR adds tests and removes `NotImplementedError` relevant to `left_index` and `right_index` parameters supplied to `pd.merge_asof`. Support was added in https://github.com/snowflakedb/snowpark-python/pull/2095 Signed-off-by: Naren Krishna --- CHANGELOG.md | 2 +- .../modin/supported/general_supported.rst | 3 +- .../compiler/snowflake_query_compiler.py | 4 +- tests/integ/modin/test_merge_asof.py | 43 ++++++++++++++----- 4 files changed, 37 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index daedfe34659..829c027c527 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ - Added support for `TimedeltaIndex.mean` method. - Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. -- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`. +- Added support for `by`, `left_by`, `right_by`, `left_index`, and `right_index` for `pd.merge_asof`. #### Bug Fixes diff --git a/docs/source/modin/supported/general_supported.rst b/docs/source/modin/supported/general_supported.rst index 95d9610202b..b3a71f023a9 100644 --- a/docs/source/modin/supported/general_supported.rst +++ b/docs/source/modin/supported/general_supported.rst @@ -38,8 +38,7 @@ Data manipulations +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``merge_asof`` | P | ``left_index``, ``right_index``, | ``N`` if param ``direction`` is ``nearest``. | -| | | , ``suffixes``, ``tolerance`` | | +| ``merge_asof`` | P | ``suffixes``, ``tolerance`` | ``N`` if param ``direction`` is ``nearest`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge_ordered`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index e971b15b6d6..a3981379aaf 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -7381,10 +7381,10 @@ def merge_asof( SnowflakeQueryCompiler """ # TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation - if left_index or right_index or tolerance or suffixes != ("_x", "_y"): + if tolerance or suffixes != ("_x", "_y"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method does not currently support parameters " - + "'left_index', 'right_index', 'suffixes', or 'tolerance'" + + "'suffixes', or 'tolerance'" ) if direction not in ("backward", "forward"): ErrorMessage.not_implemented( diff --git a/tests/integ/modin/test_merge_asof.py b/tests/integ/modin/test_merge_asof.py index 51dda7889e7..5aab91fc9cb 100644 --- a/tests/integ/modin/test_merge_asof.py +++ b/tests/integ/modin/test_merge_asof.py @@ -231,6 +231,37 @@ def test_merge_asof_left_right_on( assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) +@allow_exact_matches +@direction +@sql_count_checker(query_count=1, join_count=1) +def test_merge_asof_left_right_index(allow_exact_matches, direction): + native_left = native_pd.DataFrame({"left_val": ["a", "b", "c"]}, index=[1, 5, 10]) + native_right = native_pd.DataFrame( + {"right_val": [1, 2, 3, 6, 7]}, index=[1, 2, 3, 6, 7] + ) + + snow_left = pd.DataFrame(native_left) + snow_right = pd.DataFrame(native_right) + + native_output = native_pd.merge_asof( + native_left, + native_right, + left_index=True, + right_index=True, + direction=direction, + allow_exact_matches=allow_exact_matches, + ) + snow_output = pd.merge_asof( + snow_left, + snow_right, + left_index=True, + right_index=True, + direction=direction, + allow_exact_matches=allow_exact_matches, + ) + assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) + + @pytest.mark.parametrize("by", ["ticker", ["ticker"]]) @sql_count_checker(query_count=1, join_count=1) def test_merge_asof_by(left_right_timestamp_data, by): @@ -399,15 +430,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - + "'left_index', 'right_index', 'suffixes', or 'tolerance'" - ), - ): - pd.merge_asof(left_snow_df, right_snow_df, left_index=True, right_index=True) - with pytest.raises( - NotImplementedError, - match=( - "Snowpark pandas merge_asof method does not currently support parameters " - + "'left_index', 'right_index', 'suffixes', or 'tolerance'" + + "'suffixes', or 'tolerance'" ), ): pd.merge_asof( @@ -420,7 +443,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - + "'left_index', 'right_index', 'suffixes', or 'tolerance'" + + "'suffixes', or 'tolerance'" ), ): pd.merge_asof( From 1c83ef232d297d366b5ae468e07760af4f0c352e Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:02:28 -0700 Subject: [PATCH 15/22] use _package_lock to protect Session._packages --- src/snowflake/snowpark/_internal/udf_utils.py | 2 +- src/snowflake/snowpark/session.py | 121 +++++++++--------- 2 files changed, 63 insertions(+), 60 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 07635c8de8a..bf4ce0af9b8 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -981,7 +981,7 @@ def add_snowpark_package_to_sproc_packages( if session is None: packages = [this_package] else: - with session._lock: + with session._package_lock: session_packages = session._packages.copy() if package_name not in session_packages: packages = list(session_packages.values()) + [this_package] diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a233ce26683..0cda2dc5185 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -502,6 +502,11 @@ def __init__( self._conn = conn self._thread_store = threading.local() self._lock = threading.RLock() + + # this lock is used to protect _packages. We use introduce a new lock because add_packages + # launches a query to snowflake to get all version of packages available in snowflake. This + # query can be slow and prevent other threads from moving on waiting for _lock. + self._package_lock = threading.RLock() self._query_tag = None self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {} self._packages: Dict[str, str] = {} @@ -1116,7 +1121,7 @@ def get_packages(self) -> Dict[str, str]: The key of this ``dict`` is the package name and the value of this ``dict`` is the corresponding requirement specifier. """ - with self._lock: + with self._package_lock: return self._packages.copy() def add_packages( @@ -1208,7 +1213,7 @@ def remove_package(self, package: str) -> None: 0 """ package_name = pkg_resources.Requirement.parse(package).key - with self._lock: + with self._package_lock: if package_name in self._packages: self._packages.pop(package_name) else: @@ -1218,7 +1223,7 @@ def clear_packages(self) -> None: """ Clears all third-party packages of a user-defined function (UDF). """ - with self._lock: + with self._package_lock: self._packages.clear() def add_requirements(self, file_path: str) -> None: @@ -1567,25 +1572,26 @@ def _resolve_packages( if isinstance(self._conn, MockServerConnection): # in local testing we don't resolve the packages, we just return what is added errors = [] - with self._lock: - result_dict = self._packages.copy() - for pkg_name, _, pkg_req in package_dict.values(): - if pkg_name in result_dict and str(pkg_req) != result_dict[pkg_name]: - errors.append( - ValueError( - f"Cannot add package '{str(pkg_req)}' because {result_dict[pkg_name]} " - "is already added." + with self._package_lock: + result_dict = self._packages + for pkg_name, _, pkg_req in package_dict.values(): + if ( + pkg_name in result_dict + and str(pkg_req) != result_dict[pkg_name] + ): + errors.append( + ValueError( + f"Cannot add package '{str(pkg_req)}' because {result_dict[pkg_name]} " + "is already added." + ) ) - ) - else: - result_dict[pkg_name] = str(pkg_req) - if len(errors) == 1: - raise errors[0] - elif len(errors) > 0: - raise RuntimeError(errors) - - with self._lock: - self._packages.update(result_dict) + else: + result_dict[pkg_name] = str(pkg_req) + if len(errors) == 1: + raise errors[0] + elif len(errors) > 0: + raise RuntimeError(errors) + return list(result_dict.values()) package_table = "information_schema.packages" @@ -1600,50 +1606,47 @@ def _resolve_packages( # 'python-dateutil': 'python-dateutil==2.8.2'} # Add to packages dictionary. Make a copy of existing packages # dictionary to avoid modifying it during intermediate steps. - with self._lock: + with self._package_lock: result_dict = ( - existing_packages_dict.copy() - if existing_packages_dict is not None - else {} + existing_packages_dict if existing_packages_dict is not None else {} ) - # Retrieve list of dependencies that need to be added - dependency_packages = self._get_dependency_packages( - package_dict, - validate_package, - package_table, - result_dict, - statement_params=statement_params, - ) - - # Add dependency packages - for package in dependency_packages: - name = package.name - version = package.specs[0][1] if package.specs else None - - if name in result_dict: - if version is not None: - added_package_has_version = "==" in result_dict[name] - if added_package_has_version and result_dict[name] != str(package): - raise ValueError( - f"Cannot add dependency package '{name}=={version}' " - f"because {result_dict[name]} is already added." - ) + # Retrieve list of dependencies that need to be added + dependency_packages = self._get_dependency_packages( + package_dict, + validate_package, + package_table, + result_dict, + statement_params=statement_params, + ) + + # Add dependency packages + for package in dependency_packages: + name = package.name + version = package.specs[0][1] if package.specs else None + + if name in result_dict: + if version is not None: + added_package_has_version = "==" in result_dict[name] + if added_package_has_version and result_dict[name] != str( + package + ): + raise ValueError( + f"Cannot add dependency package '{name}=={version}' " + f"because {result_dict[name]} is already added." + ) + result_dict[name] = str(package) + else: result_dict[name] = str(package) - else: - result_dict[name] = str(package) - # Always include cloudpickle - extra_modules = [cloudpickle] - if include_pandas: - extra_modules.append("pandas") + # Always include cloudpickle + extra_modules = [cloudpickle] + if include_pandas: + extra_modules.append("pandas") - with self._lock: - if existing_packages_dict is not None: - existing_packages_dict.update(result_dict) - return list(result_dict.values()) + self._get_req_identifiers_list( - extra_modules, result_dict - ) + return list(result_dict.values()) + self._get_req_identifiers_list( + extra_modules, result_dict + ) def _upload_unsupported_packages( self, From 6a3a77bd04f1372257d8ec31f7fbdc9de74f21be Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Tue, 17 Sep 2024 14:09:55 -0700 Subject: [PATCH 16/22] [SNOW-1320248]: Fix `inplace=True` on Series objects derived from Series. (#2307) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1320248 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. This PR fixes the inplace argument for Series functions where the Series is derived from another Series object (e.g. series = series.iloc[:4]; series.fillna(14, inplace=True)) --------- Co-authored-by: Hazem Elmeleegy --- CHANGELOG.md | 2 +- .../plugin/extensions/series_overrides.py | 19 +++++++++++++++++++ tests/integ/modin/series/test_fillna.py | 16 ++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 829c027c527..90f2003251d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ - Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`. - Suppressed an unhelpful `SettingWithCopyWarning` that sometimes appeared when printing `Timedelta` columns. - +- Fixed `inplace` argument for `Series` objects derived from other `Series` objects. ## 1.22.1 (2024-09-11) This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index b104c223e26..625e5b8032a 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -392,6 +392,25 @@ def __init__( self.name = name +@register_series_accessor("_update_inplace") +def _update_inplace(self, new_query_compiler) -> None: + """ + Update the current Series in-place using `new_query_compiler`. + + Parameters + ---------- + new_query_compiler : BaseQueryCompiler + QueryCompiler to use to manage the data. + """ + super(Series, self)._update_inplace(new_query_compiler=new_query_compiler) + # Propagate changes back to parent so that column in dataframe had the same contents + if self._parent is not None: + if self._parent_axis == 1 and isinstance(self._parent, DataFrame): + self._parent[self.name] = self + else: + self._parent.loc[self.index] = self + + # Since Snowpark pandas leaves all data on the warehouse, memory_usage's report of local memory # usage isn't meaningful and is set to always return 0. @_inherit_docstrings(native_pd.Series.memory_usage, apilink="pandas.Series") diff --git a/tests/integ/modin/series/test_fillna.py b/tests/integ/modin/series/test_fillna.py index 9371cd0dcd1..80997070b92 100644 --- a/tests/integ/modin/series/test_fillna.py +++ b/tests/integ/modin/series/test_fillna.py @@ -3,6 +3,8 @@ # +import string + import modin.pandas as pd import numpy as np import pandas as native_pd @@ -201,3 +203,17 @@ def inplace_fillna(df): native_pd.DataFrame([[1, 2, 3], [4, None, 6]], columns=list("ABC")), inplace_fillna, ) + + +@pytest.mark.parametrize("index", [list(range(8)), list(string.ascii_lowercase[:8])]) +@sql_count_checker(query_count=1, join_count=4) +def test_inplace_fillna_from_series(index): + def inplace_fillna(series): + series.iloc[:4].fillna(14, inplace=True) + return series + + eval_snowpark_pandas_result( + pd.Series([np.nan, 1, 2, 3, 4, 5, 6, 7], index=index), + native_pd.Series([np.nan, 1, 2, 3, 4, 5, 6, 7], index=index), + inplace_fillna, + ) From a6497610be5072357c54b5c64bb8df95c65ef165 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:14:21 -0700 Subject: [PATCH 17/22] undo refactor --- src/snowflake/snowpark/_internal/udf_utils.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index bf4ce0af9b8..58d698556b3 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -982,9 +982,8 @@ def add_snowpark_package_to_sproc_packages( packages = [this_package] else: with session._package_lock: - session_packages = session._packages.copy() - if package_name not in session_packages: - packages = list(session_packages.values()) + [this_package] + if package_name not in session._packages: + packages = list(session._packages.values()) + [this_package] else: package_names = [p if isinstance(p, str) else p.__name__ for p in packages] if not any(p.startswith(package_name) for p in package_names): @@ -1076,20 +1075,19 @@ def resolve_imports_and_packages( ) ) - all_urls = [] if session is not None: import_only_stage = ( unwrap_stage_location_single_quote(stage_location) if stage_location else session.get_session_stage(statement_params=statement_params) ) - upload_and_import_stage = ( import_only_stage if is_permanent else session.get_session_stage(statement_params=statement_params) ) + if session: if imports: udf_level_imports = {} for udf_import in imports: @@ -1117,15 +1115,22 @@ def resolve_imports_and_packages( upload_and_import_stage, statement_params=statement_params, ) + else: + all_urls = [] + else: + all_urls = [] dest_prefix = get_udf_upload_prefix(udf_name) # Upload closure to stage if it is beyond inline closure size limit handler = inline_code = upload_file_stage_location = None - # As cloudpickle is being used, we cannot allow a custom runtime - custom_python_runtime_version_allowed = not isinstance(func, Callable) + custom_python_runtime_version_allowed = False if session is not None: if isinstance(func, Callable): + custom_python_runtime_version_allowed = ( + False # As cloudpickle is being used, we cannot allow a custom runtime + ) + # generate a random name for udf py file # and we compress it first then upload it udf_file_name_base = f"udf_py_{random_number()}" @@ -1170,6 +1175,7 @@ def resolve_imports_and_packages( upload_file_stage_location = None handler = _DEFAULT_HANDLER_NAME else: + custom_python_runtime_version_allowed = True udf_file_name = os.path.basename(func[0]) # for a compressed file, it might have multiple extensions # and we should remove all extensions @@ -1194,6 +1200,11 @@ def resolve_imports_and_packages( skip_upload_on_content_match=skip_upload_on_content_match, ) all_urls.append(upload_file_stage_location) + else: + if isinstance(func, Callable): + custom_python_runtime_version_allowed = False + else: + custom_python_runtime_version_allowed = True # build imports and packages string all_imports = ",".join( From f03d6186f84b4f5b593d1920d820e218347e3e3a Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:15:50 -0700 Subject: [PATCH 18/22] undo refactor --- src/snowflake/snowpark/_internal/udf_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 58d698556b3..25921fff821 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -1081,6 +1081,7 @@ def resolve_imports_and_packages( if stage_location else session.get_session_stage(statement_params=statement_params) ) + upload_and_import_stage = ( import_only_stage if is_permanent From 5f398d5f00edb42ca3f19d656f656fdc91a9c99c Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:33:07 -0700 Subject: [PATCH 19/22] fix test --- tests/unit/test_udf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index c23755c14a3..09e389d0c24 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -250,7 +250,7 @@ def test_add_snowpark_package_to_sproc_packages_to_session(): "random_package_one": "random_package_one", "random_package_two": "random_package_two", } - fake_session._lock = threading.RLock() + fake_session._package_lock = threading.RLock() result = add_snowpark_package_to_sproc_packages(session=fake_session, packages=None) major, minor, patch = VERSION From 380708778a6678f9f2bdbbf9a6e95a3b5074bf19 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 14:36:20 -0700 Subject: [PATCH 20/22] fix test --- tests/unit/test_server_connection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_server_connection.py b/tests/unit/test_server_connection.py index 72ccb6f6c42..cf10dc9d29a 100644 --- a/tests/unit/test_server_connection.py +++ b/tests/unit/test_server_connection.py @@ -119,6 +119,7 @@ def test_get_result_set_exception(mock_server_connection): fake_session._last_canceled_id = 100 fake_session._conn = mock_server_connection fake_session._cte_optimization_enabled = False + fake_session._query_compilation_stage_enabled = False fake_plan = SnowflakePlan( queries=[Query("fake query 1"), Query("fake query 2")], schema_query="fake schema query", From df3263c20c50beabe3f861c3db502ff2c8830c34 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Thu, 12 Sep 2024 15:06:48 -0700 Subject: [PATCH 21/22] add file IO tests --- tests/integ/conftest.py | 23 ++++++- .../integ/scala/test_file_operation_suite.py | 17 +---- tests/integ/test_multithreading.py | 69 ++++++++++++++++++- 3 files changed, 91 insertions(+), 18 deletions(-) diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index ec619605e66..319abb137f4 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -13,7 +13,13 @@ from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.mock._connection import MockServerConnection from tests.parameters import CONNECTION_PARAMETERS -from tests.utils import TEST_SCHEMA, Utils, running_on_jenkins, running_on_public_ci +from tests.utils import ( + TEST_SCHEMA, + TestFiles, + Utils, + running_on_jenkins, + running_on_public_ci, +) def print_help() -> None: @@ -235,3 +241,18 @@ def temp_schema(connection, session, local_testing_mode) -> None: ) yield temp_schema_name cursor.execute(f"DROP SCHEMA IF EXISTS {temp_schema_name}") + + +@pytest.fixture(scope="module") +def temp_stage(session, resources_path, local_testing_mode): + tmp_stage_name = Utils.random_stage_name() + test_files = TestFiles(resources_path) + + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + Utils.upload_to_stage( + session, tmp_stage_name, test_files.test_file_parquet, compress=False + ) + yield tmp_stage_name + if not local_testing_mode: + Utils.drop_stage(session, tmp_stage_name) diff --git a/tests/integ/scala/test_file_operation_suite.py b/tests/integ/scala/test_file_operation_suite.py index 2dc424dde09..82a1722a729 100644 --- a/tests/integ/scala/test_file_operation_suite.py +++ b/tests/integ/scala/test_file_operation_suite.py @@ -14,7 +14,7 @@ SnowparkSQLException, SnowparkUploadFileException, ) -from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, TestFiles, Utils +from tests.utils import IS_IN_STORED_PROC, IS_WINDOWS, Utils def random_alphanumeric_name(): @@ -74,21 +74,6 @@ def path4(temp_source_directory): yield filename -@pytest.fixture(scope="module") -def temp_stage(session, resources_path, local_testing_mode): - tmp_stage_name = Utils.random_stage_name() - test_files = TestFiles(resources_path) - - if not local_testing_mode: - Utils.create_stage(session, tmp_stage_name, is_temporary=True) - Utils.upload_to_stage( - session, tmp_stage_name, test_files.test_file_parquet, compress=False - ) - yield tmp_stage_name - if not local_testing_mode: - Utils.drop_stage(session, tmp_stage_name) - - def test_put_with_one_file( session, temp_stage, path1, path2, path3, local_testing_mode ): diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 164a4b7b590..10fcc6ef70d 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -2,6 +2,9 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import hashlib +import os +import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed from unittest.mock import patch @@ -9,7 +12,7 @@ from snowflake.snowpark.functions import lit from snowflake.snowpark.row import Row -from tests.utils import IS_IN_STORED_PROC, Utils +from tests.utils import IS_IN_STORED_PROC, TestFiles, Utils def test_concurrent_select_queries(session): @@ -122,3 +125,67 @@ def test_action_ids_are_unique(session): action_ids.add(future.result()) assert len(action_ids) == 10 + + +@pytest.mark.parametrize("use_stream", [True, False]) +def test_file_io(session, resources_path, temp_stage, use_stream): + stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}" + stage_with_prefix = f"@{temp_stage}/{stage_prefix}/" + test_files = TestFiles(resources_path) + + resources_files = [ + test_files.test_file_csv, + test_files.test_file2_csv, + test_files.test_file_json, + test_files.test_file_csv_header, + test_files.test_file_csv_colon, + test_files.test_file_csv_quotes, + test_files.test_file_csv_special_format, + test_files.test_file_json_special_format, + test_files.test_file_csv_quotes_special, + test_files.test_concat_file1_csv, + test_files.test_concat_file2_csv, + ] + + def get_file_hash(fd): + return hashlib.md5(fd.read()).hexdigest() + + def put_and_get_file(upload_file_path, download_dir): + if use_stream: + with open(upload_file_path, "rb") as fd: + results = session.file.put_stream( + fd, stage_with_prefix, auto_compress=False, overwrite=False + ) + else: + results = session.file.put( + upload_file_path, + stage_with_prefix, + auto_compress=False, + overwrite=False, + ) + # assert file is uploaded successfully + assert len(results) == 1 + assert results[0].status == "UPLOADED" + + stage_file_name = f"{stage_with_prefix}{os.path.basename(upload_file_path)}" + if use_stream: + fd = session.file.get_stream(stage_file_name, download_dir) + with open(upload_file_path, "rb") as upload_fd: + assert get_file_hash(upload_fd) == get_file_hash(fd) + + else: + results = session.file.get(stage_file_name, download_dir) + # assert file is downloaded successfully + assert len(results) == 1 + assert results[0].status == "DOWNLOADED" + download_file_path = results[0].file + # assert two files are identical + with open(upload_file_path, "rb") as upload_fd, open( + download_file_path, "rb" + ) as download_fd: + assert get_file_hash(upload_fd) == get_file_hash(download_fd) + + with tempfile.TemporaryDirectory() as download_dir: + with ThreadPoolExecutor(max_workers=10) as executor: + for file_path in resources_files: + executor.submit(put_and_get_file, file_path, download_dir) From a737f33a1abfcdc5c1457703c5a46f98307fc9f2 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Tue, 17 Sep 2024 15:24:27 -0700 Subject: [PATCH 22/22] fix test --- tests/unit/compiler/test_large_query_breakdown.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/unit/compiler/test_large_query_breakdown.py b/tests/unit/compiler/test_large_query_breakdown.py index d040ca25f49..5c9e140694f 100644 --- a/tests/unit/compiler/test_large_query_breakdown.py +++ b/tests/unit/compiler/test_large_query_breakdown.py @@ -105,7 +105,12 @@ ], ) def test_pipeline_breaker_node(mock_session, mock_analyzer, node_generator, expected): - large_query_breakdown = LargeQueryBreakdown(mock_session, mock_analyzer, []) + large_query_breakdown = LargeQueryBreakdown( + mock_session, + mock_analyzer, + [], + mock_session.large_query_breakdown_complexity_bounds, + ) node = node_generator(mock_analyzer) assert (