diff --git a/CHANGELOG.md b/CHANGELOG.md index fea42391259..daedfe34659 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,36 @@ # Release History -## 1.22.0 (TBD) +## 1.23.0 (TBD) + +### Snowpark pandas API Updates + +#### Improvements + +- Improved `to_pandas` to persist the original timezone offset for TIMESTAMP_TZ type. + +#### New Features + +- Added support for `TimedeltaIndex.mean` method. +- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. +- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`. + +#### Bug Fixes + +- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`. +- Suppressed an unhelpful `SettingWithCopyWarning` that sometimes appeared when printing `Timedelta` columns. + + +## 1.22.1 (2024-09-11) +This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. + + +## 1.22.0 (2024-09-10) ### Snowpark Python API Updates ### New Features -- Added following new functions in `snowflake.snowpark.functions`: +- Added the following new functions in `snowflake.snowpark.functions`: - `array_remove` - `ln` @@ -46,14 +70,14 @@ - Fixed a bug in `session.read.csv` that caused an error when setting `PARSE_HEADER = True` in an externally defined file format. - Fixed a bug in query generation from set operations that allowed generation of duplicate queries when children have common subqueries. - Fixed a bug in `session.get_session_stage` that referenced a non-existing stage after switching database or schema. -- Fixed a bug where calling `DataFrame.to_snowpark_pandas_dataframe` without explicitly initializing the Snowpark pandas plugin caused an error. +- Fixed a bug where calling `DataFrame.to_snowpark_pandas` without explicitly initializing the Snowpark pandas plugin caused an error. - Fixed a bug where using the `explode` function in dynamic table creation caused a SQL compilation error due to improper boolean type casting on the `outer` parameter. ### Snowpark Local Testing Updates #### New Features -- Added support for type coercion when passing columns as input to udf calls +- Added support for type coercion when passing columns as input to UDF calls. - Added support for `Index.identical`. #### Bug Fixes @@ -105,6 +129,9 @@ - Added support for creating a `DatetimeIndex` from an `Index` of numeric or string type. - Added support for string indexing with `Timedelta` objects. - Added support for `Series.dt.total_seconds` method. +- Added support for `DataFrame.apply(axis=0)`. +- Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`. +- Added support for `DatetimeIndex.tz_convert` and `DatetimeIndex.tz_localize`. #### Improvements @@ -113,9 +140,11 @@ - Improved `pd.to_datetime` to handle all local input cases. - Create a lazy index from another lazy index without pulling data to client. - Raised `NotImplementedError` for Index bitwise operators. -- Display a clearer error message when `Index.names` is set to a non-like-like object. +- Display a more clear error message when `Index.names` is set to a non-like-like object. - Raise a warning whenever MultiIndex values are pulled in locally. - Improve warning message for `pd.read_snowflake` include the creation reason when temp table creation is triggered. 
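As a rough illustration of the new `pd.merge_asof` grouping support called out in the 1.23.0 notes above — a minimal sketch that assumes an active Snowpark session with the Snowpark pandas plugin imported, and uses made-up sample data:

```python
# Minimal sketch of pd.merge_asof with the newly supported `by` grouping.
# Assumes `import modin.pandas as pd` backed by Snowpark pandas and an
# active Snowflake session; data and column names are illustrative only.
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowpark backend

quotes = pd.DataFrame(
    {
        "time": pd.to_datetime(["2024-09-10 09:30:00", "2024-09-10 09:30:02"]),
        "ticker": ["SNOW", "SNOW"],
        "bid": [112.0, 112.5],
    }
)
trades = pd.DataFrame(
    {
        "time": pd.to_datetime(["2024-09-10 09:30:01"]),
        "ticker": ["SNOW"],
        "qty": [100],
    }
)

# Match each trade with the most recent quote for the same ticker.
merged = pd.merge_asof(trades, quotes, on="time", by="ticker")
print(merged)
```

`left_by`/`right_by` work the same way when the grouping columns are named differently on each side, while `direction="nearest"` remains unsupported per the table below.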
+- Improve performance for `DataFrame.set_index`, or setting `DataFrame.index` or `Series.index` by avoiding checks require eager evaluation. As a consequence, when the new index that does not match the current `Series`/`DataFrame` object length, a `ValueError` is no longer raised. Instead, when the `Series`/`DataFrame` object is longer than the provided index, the `Series`/`DataFrame`'s new index is filled with `NaN` values for the "extra" elements. Otherwise, the extra values in the provided index are ignored. +- Properly raise `NotImplementedError` when ambiguous/nonexistent are non-string in `ceil`/`floor`/`round`. #### Bug Fixes @@ -126,10 +155,6 @@ - Fixed a bug where `Series.reindex` and `DataFrame.reindex` did not update the result index's name correctly. - Fixed a bug where `Series.take` did not error when `axis=1` was specified. -#### Behavior Change - -- When calling `DataFrame.set_index`, or setting `DataFrame.index` or `Series.index`, with a new index that does not match the current length of the `Series`/`DataFrame` object, a `ValueError` is no longer raised. When the `Series`/`DataFrame` object is longer than the new index, the `Series`/`DataFrame`'s new index is filled with `NaN` values for the "extra" elements. When the `Series`/`DataFrame` object is shorter than the new index, the extra values in the new index are ignored—`Series` and `DataFrame` stay the same length `n`, and use only the first `n` values of the new index. - ## 1.21.1 (2024-09-05) diff --git a/docs/source/modin/series.rst b/docs/source/modin/series.rst index 188bdab344a..4cb8a238b0f 100644 --- a/docs/source/modin/series.rst +++ b/docs/source/modin/series.rst @@ -279,6 +279,8 @@ Series Series.dt.seconds Series.dt.microseconds Series.dt.nanoseconds + Series.dt.tz_convert + Series.dt.tz_localize .. rubric:: String accessor methods diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst index 6bb214e3bd6..54858063e54 100644 --- a/docs/source/modin/supported/dataframe_supported.rst +++ b/docs/source/modin/supported/dataframe_supported.rst @@ -84,7 +84,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``any`` | P | | ``N`` for non-integer/boolean types | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``apply`` | P | | ``N`` if ``axis == 0`` or ``func`` is not callable | +| ``apply`` | P | | ``N`` if ``func`` is not callable | | | | | or ``result_type`` is given or ``args`` and | | | | | ``kwargs`` contain DataFrame or Series | | | | | ``N`` if ``func`` maps to different column labels. | @@ -471,8 +471,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``to_xml`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``transform`` | P | | Only callable and string parameters are supported.| -| | | | list and dict parameters are not supported. | +| ``transform`` | P | | ``Y`` if ``func`` is callable. 
| +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``transpose`` | P | | See ``T`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst index 68b1935da96..3afe671aee7 100644 --- a/docs/source/modin/supported/datetime_index_supported.rst +++ b/docs/source/modin/supported/datetime_index_supported.rst @@ -82,9 +82,9 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``snap`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``tz_convert`` | N | | | +| ``tz_convert`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``tz_localize`` | N | | | +| ``tz_localize`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``round`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/general_supported.rst b/docs/source/modin/supported/general_supported.rst index 797ef3bbd59..95d9610202b 100644 --- a/docs/source/modin/supported/general_supported.rst +++ b/docs/source/modin/supported/general_supported.rst @@ -38,8 +38,7 @@ Data manipulations +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``merge_asof`` | P | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``. | -| | | , ``left_index``, ``right_index``| | +| ``merge_asof`` | P | ``left_index``, ``right_index``, | ``N`` if param ``direction`` is ``nearest``. | | | | , ``suffixes``, ``tolerance`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge_ordered`` | N | | | diff --git a/docs/source/modin/supported/series_dt_supported.rst b/docs/source/modin/supported/series_dt_supported.rst index 3377a3d64e2..68853871ea6 100644 --- a/docs/source/modin/supported/series_dt_supported.rst +++ b/docs/source/modin/supported/series_dt_supported.rst @@ -80,9 +80,10 @@ the method in the left column. +-----------------------------+---------------------------------+----------------------------------------------------+ | ``to_pydatetime`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``tz_localize`` | N | | +| ``tz_localize`` | P | ``N`` if `ambiguous` or `nonexistent` are set to a | +| | | non-default value. 
| +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``tz_convert`` | N | | +| ``tz_convert`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``normalize`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/timedelta_index_supported.rst b/docs/source/modin/supported/timedelta_index_supported.rst index 49dfcb305e4..f7a34c3552c 100644 --- a/docs/source/modin/supported/timedelta_index_supported.rst +++ b/docs/source/modin/supported/timedelta_index_supported.rst @@ -44,7 +44,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ | ``ceil`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ -| ``mean`` | N | | | +| ``mean`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ | ``total_seconds`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ diff --git a/recipe/meta.yaml b/recipe/meta.yaml index cf1f2c9ad70..9560f4a4408 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -1,5 +1,5 @@ {% set name = "snowflake-snowpark-python" %} -{% set version = "1.21.1" %} +{% set version = "1.22.1" %} package: name: {{ name|lower }} diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 76e91b7da92..d8622299ea9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -956,10 +956,7 @@ def do_resolve_with_resolved_children( schema_query = schema_query_for_values_statement(logical_plan.output) if logical_plan.data: - if ( - len(logical_plan.output) * len(logical_plan.data) - < ARRAY_BIND_THRESHOLD - ): + if not logical_plan.is_large_local_data: return self.plan_builder.query( values_statement(logical_plan.output, logical_plan.data), logical_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 3ed969caada..22591f55e47 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
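A minimal usage sketch for the `tz_localize`/`tz_convert` support reported in the tables above (assuming Snowpark pandas is initialized with an active session; the timestamps are arbitrary, and non-default `ambiguous`/`nonexistent` arguments still raise `NotImplementedError`):

```python
# Sketch of the newly supported datetime timezone methods.
# Assumes `import modin.pandas as pd` backed by Snowpark pandas; the
# timestamps below are arbitrary sample data.
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

s = pd.Series(pd.to_datetime(["2024-09-10 09:00:00", "2024-09-10 17:00:00"]))

# Attach a timezone to naive timestamps, then convert to another zone.
localized = s.dt.tz_localize("UTC")
converted = localized.dt.tz_convert("America/Los_Angeles")

# The same methods are available on a DatetimeIndex.
idx = pd.DatetimeIndex(["2024-09-10 09:00:00"]).tz_localize("UTC")
print(converted, idx.tz_convert("Europe/London"), sep="\n")
```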
# -from typing import AbstractSet, Optional +from typing import AbstractSet, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -29,6 +30,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.left, self.right) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index a2d21db4eb2..a7cb5fd97a9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -35,6 +35,13 @@ def derive_dependent_columns( *expressions: "Optional[Expression]", ) -> Optional[AbstractSet[str]]: + """ + Given set of expressions, derive the set of columns that the expressions dependents on. + + Note, the returned dependent columns is a set without duplication. For example, given expression + concat(col1, upper(co1), upper(col2)), the result will be {col1, col2} even if col1 has + occurred in the given expression twice. + """ result = set() for exp in expressions: if exp is not None: @@ -48,6 +55,23 @@ def derive_dependent_columns( return result +def derive_dependent_columns_with_duplication( + *expressions: "Optional[Expression]", +) -> List[str]: + """ + Given set of expressions, derive the list of columns that the expression dependents on. + + Note, the returned columns will have duplication if the column occurred more than once in + the given expression. For example, concat(col1, upper(co1), upper(col2)) will have result + [col1, col1, col2], where col1 occurred twice in the result. + """ + result = [] + for exp in expressions: + if exp is not None: + result.extend(exp.dependent_column_names_with_duplication()) + return result + + class Expression: """Consider removing attributes, and adding properties and methods. A subclass of Expression may have no child, one child, or multiple children. @@ -68,6 +92,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. return COLUMN_DEPENDENCY_EMPTY + def dependent_column_names_with_duplication(self) -> List[str]: + return [] + @property def pretty_name(self) -> str: """Returns a user-facing string representation of this expression's name. 
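A standalone toy sketch of the two derivation flavors documented above — simplified stand-in classes rather than the real Snowpark expression nodes — showing why `derive_dependent_columns` de-duplicates while the new `_with_duplication` variant preserves repeats:

```python
# Toy model of dependent-column derivation, mirroring the docstrings above.
# `Col` and `Func` are hypothetical stand-ins for Snowpark expression nodes.
from typing import List, Set


class Col:
    def __init__(self, name: str) -> None:
        self.name = name

    def dependent_column_names(self) -> Set[str]:
        return {self.name}

    def dependent_column_names_with_duplication(self) -> List[str]:
        return [self.name]


class Func:
    def __init__(self, *children) -> None:
        self.children = children

    def dependent_column_names(self) -> Set[str]:
        result = set()
        for child in self.children:
            result.update(child.dependent_column_names())
        return result

    def dependent_column_names_with_duplication(self) -> List[str]:
        result = []
        for child in self.children:
            result.extend(child.dependent_column_names_with_duplication())
        return result


# concat(col1, upper(col1), upper(col2))
expr = Func(Col("col1"), Func(Col("col1")), Func(Col("col2")))
print(expr.dependent_column_names())                   # {'col1', 'col2'}
print(expr.dependent_column_names_with_duplication())  # ['col1', 'col1', 'col2']
```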
@@ -143,6 +170,9 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR + def dependent_column_names_with_duplication(self) -> List[str]: + return list(COLUMN_DEPENDENCY_DOLLAR) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.plan.cumulative_node_complexity @@ -156,6 +186,9 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.expressions) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( @@ -172,6 +205,9 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.columns, *self.values) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN @@ -212,6 +248,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} + def dependent_column_names_with_duplication(self) -> List[str]: + return [self.name] + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -235,6 +274,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: else COLUMN_DEPENDENCY_ALL ) + def dependent_column_names_with_duplication(self) -> List[str]: + return ( + derive_dependent_columns_with_duplication(*self.expressions) + if self.expressions + else [] # we currently do not handle * dependency + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1} @@ -278,6 +324,14 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names + def dependent_column_names_with_duplication(self) -> List[str]: + return ( + [] + if (self._dependent_column_names == COLUMN_DEPENDENCY_ALL) + or (self._dependent_column_names is None) + else list(self._dependent_column_names) + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -371,6 +425,9 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern @@ -400,6 +457,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern @@ -423,6 +483,9 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + 
def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec @@ -444,6 +507,9 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -466,6 +532,9 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -510,6 +579,9 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -525,6 +597,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols) + @property def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) @@ -549,13 +624,21 @@ def __init__( self.branches = branches self.else_value = else_value - def dependent_column_names(self) -> Optional[AbstractSet[str]]: + @property + def _child_expressions(self) -> List[Expression]: exps = [] for exp_tuple in self.branches: exps.extend(exp_tuple) if self.else_value is not None: exps.append(self.else_value) - return derive_dependent_columns(*exps) + + return exps + + def dependent_column_names(self) -> Optional[AbstractSet[str]]: + return derive_dependent_columns(*self._child_expressions) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self._child_expressions) @property def plan_node_category(self) -> PlanNodeCategory: @@ -602,6 +685,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -617,6 +703,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.col) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -636,6 +725,9 @@ def __init__(self, exprs: List[Expression]) -> 
None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.exprs) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 84cd63fd87d..012940471d0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.group_by_exprs) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) + def dependent_column_names_with_duplication(self) -> List[str]: + flattened_args = [exp for sublist in self.args for exp in sublist] + return derive_dependent_columns_with_duplication(*flattened_args) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index e3e032cd94b..aa8730dcf7f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -144,10 +144,27 @@ def __init__( self.data = data self.schema_query = schema_query + @property + def is_large_local_data(self) -> bool: + from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD + + return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self.is_large_local_data: + # When the number of literals exceeds the threshold, we generate 3 queries: + # 1. create table query + # 2. insert into table query + # 3. select * from table query + # We only consider the complexity from the final select * query since other queries + # are built based on it. 
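A simplified, self-contained sketch of the threshold logic described in the comment above; the threshold value and function shape are illustrative stand-ins for `ARRAY_BIND_THRESHOLD` and the node's `individual_node_complexity` property:

```python
# Self-contained sketch of the large-local-data check and its effect on the
# complexity estimate; the threshold value here is illustrative only.
ARRAY_BIND_THRESHOLD = 512  # stand-in for the real Snowpark constant


def values_node_complexity(num_rows: int, num_columns: int) -> dict:
    if num_rows * num_columns >= ARRAY_BIND_THRESHOLD:
        # Large local data becomes CREATE TABLE + INSERT + SELECT *, and only
        # the final SELECT contributes to the score.
        return {"COLUMN": 1}
    # Small local data stays inline:
    # SELECT $1, ..., $m FROM VALUES (r11, ..., r1m), ..., (rn1, ..., rnm)
    return {"COLUMN": num_columns, "LITERAL": num_rows * num_columns}


print(values_node_complexity(num_rows=10, num_columns=3))     # inline VALUES path
print(values_node_complexity(num_rows=1_000, num_columns=3))  # temp-table path
```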
+ return { + PlanNodeCategory.COLUMN: 1, + } + + # If we stay under the threshold, we generate a single query: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) - # TODO: use ARRAY_BIND_THRESHOLD return { PlanNodeCategory.COLUMN: len(self.output), PlanNodeCategory.LITERAL: len(self.data) * len(self.output), diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 1d06f7290a0..82451245e4c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional, Type +from typing import AbstractSet, List, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) @@ -55,3 +56,6 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e5886e11069..1ae08e8fde2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,12 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Dict, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -36,6 +37,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 69db3f265ce..4381c4a2e22 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -71,6 +72,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.lower, self.upper) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) + def dependent_column_names_with_duplication(self) 
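A rough sketch of how the now session-configurable `(lower, upper)` complexity bounds drive the breakdown loop above; the scores and the partition-finding callback here are simplified stand-ins, not the real Snowpark plan machinery:

```python
# Simplified sketch of the bounds-driven breakdown loop; the callback and
# scores below are stand-ins for the real Snowpark plan machinery.
from typing import Callable, List, Optional, Tuple


def try_to_breakdown_plan(
    root_score: int,
    bounds: Tuple[int, int],
    find_partition_score: Callable[[int, int], Optional[int]],
) -> List[int]:
    lower, upper = bounds
    if root_score <= upper:
        return [root_score]  # within budget: keep the plan as a single query

    partition_scores = []
    while root_score > upper:
        child_score = find_partition_score(lower, upper)
        if child_score is None:
            break  # no valid pipeline-breaker node with score in (lower, upper)
        partition_scores.append(child_score)
        root_score -= child_score
    return partition_scores + [root_score]


# With bounds (10M, 12M), a 25M plan is cut into roughly 11M-sized pieces.
print(try_to_breakdown_plan(
    25_000_000, (10_000_000, 12_000_000), lambda lo, hi: 11_000_000
))
```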
-> List[str]: + return derive_dependent_columns_with_duplication( + *self.partition_spec, *self.order_spec, self.frame_spec + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # partition_spec order_by_spec frame_spec @@ -138,6 +147,11 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + self.window_function, self.window_spec + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW @@ -171,6 +185,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.default) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # for func_name diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index 836628345aa..8d16383a4ce 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -58,11 +58,6 @@ ) from snowflake.snowpark.session import Session -# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT -# in Snowflake. This is the limit where we start seeing compilation errors. -COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000 -COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000 - _logger = logging.getLogger(__name__) @@ -123,6 +118,12 @@ def __init__( self._query_generator = query_generator self.logical_plans = logical_plans self._parent_map = defaultdict(set) + self.complexity_score_lower_bound = ( + session.large_query_breakdown_complexity_bounds[0] + ) + self.complexity_score_upper_bound = ( + session.large_query_breakdown_complexity_bounds[1] + ) def apply(self) -> List[LogicalPlan]: if is_active_transaction(self.session): @@ -183,13 +184,13 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]: complexity_score = get_complexity_score(root.cumulative_node_complexity) _logger.debug(f"Complexity score for root {type(root)} is: {complexity_score}") - if complexity_score <= COMPLEXITY_SCORE_UPPER_BOUND: + if complexity_score <= self.complexity_score_upper_bound: # Skip optimization if the complexity score is within the upper bound. 
return [root] plans = [] # TODO: SNOW-1617634 Have a one pass algorithm to find the valid node for partitioning - while complexity_score > COMPLEXITY_SCORE_UPPER_BOUND: + while complexity_score > self.complexity_score_upper_bound: child = self._find_node_to_breakdown(root) if child is None: _logger.debug( @@ -277,7 +278,9 @@ def _is_node_valid_to_breakdown(self, node: LogicalPlan) -> Tuple[bool, int]: """ score = get_complexity_score(node.cumulative_node_complexity) valid_node = ( - COMPLEXITY_SCORE_LOWER_BOUND < score < COMPLEXITY_SCORE_UPPER_BOUND + self.complexity_score_lower_bound + < score + < self.complexity_score_upper_bound ) and self._is_node_pipeline_breaker(node) if valid_node: _logger.debug( diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index bef53f0f389..3e6dba71be4 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -3,8 +3,12 @@ # import copy +import time from typing import Dict, List +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + get_complexity_score, +) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( PlanQueryType, Query, @@ -17,7 +21,11 @@ from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import ( RepeatedSubqueryElimination, ) +from snowflake.snowpark._internal.compiler.telemetry_constants import ( + CompilationStageTelemetryField, +) from snowflake.snowpark._internal.compiler.utils import create_query_generator +from snowflake.snowpark._internal.telemetry import TelemetryField from snowflake.snowpark.mock._connection import MockServerConnection @@ -68,24 +76,71 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: if self.should_start_query_compilation(): # preparation for compilation # 1. make a copy of the original plan + start_time = time.time() + complexity_score_before_compilation = get_complexity_score( + self._plan.cumulative_node_complexity + ) logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)] + deep_copy_end_time = time.time() + # 2. create a code generator with the original plan query_generator = create_query_generator(self._plan) - # apply each optimizations if needed + # 3. apply each optimizations if needed + # CTE optimization + cte_start_time = time.time() if self._plan.session.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) logical_plans = repeated_subquery_eliminator.apply() + + cte_end_time = time.time() + complexity_scores_after_cte = [ + get_complexity_score(logical_plan.cumulative_node_complexity) + for logical_plan in logical_plans + ] + + # Large query breakdown if self._plan.session.large_query_breakdown_enabled: large_query_breakdown = LargeQueryBreakdown( self._plan.session, query_generator, logical_plans ) logical_plans = large_query_breakdown.apply() - # do a final pass of code generation - return query_generator.generate_queries(logical_plans) + large_query_breakdown_end_time = time.time() + complexity_scores_after_large_query_breakdown = [ + get_complexity_score(logical_plan.cumulative_node_complexity) + for logical_plan in logical_plans + ] + + # 4. 
do a final pass of code generation + queries = query_generator.generate_queries(logical_plans) + + # log telemetry data + deep_copy_time = deep_copy_end_time - start_time + cte_time = cte_end_time - cte_start_time + large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time + total_time = time.time() - start_time + session = self._plan.session + summary_value = { + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, + CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, + CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, + CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, + CompilationStageTelemetryField.TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN.value: large_query_breakdown_time, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation, + CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte, + CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown, + } + session._conn._telemetry_client.send_query_compilation_summary_telemetry( + session_id=session.session_id, + plan_uuid=self._plan.uuid, + compilation_stage_summary=summary_value, + ) + return queries else: final_plan = self._plan if self._plan.session.cte_optimization_enabled: diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index 3c1f0d4fc5d..be61a1ac924 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -6,10 +6,28 @@ class CompilationStageTelemetryField(Enum): + # types TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED = ( "snowpark_large_query_breakdown_optimization_skipped" ) + TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics" + TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS = ( + "snowpark_large_query_breakdown_update_complexity_bounds" + ) + + # keys KEY_REASON = "reason" + PLAN_UUID = "plan_uuid" + TIME_TAKEN_FOR_COMPILATION = "time_taken_for_compilation_sec" + TIME_TAKEN_FOR_DEEP_COPY_PLAN = "time_taken_for_deep_copy_plan_sec" + TIME_TAKEN_FOR_CTE_OPTIMIZATION = "time_taken_for_cte_optimization_sec" + TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN = "time_taken_for_large_query_breakdown_sec" + COMPLEXITY_SCORE_BOUNDS = "complexity_score_bounds" + COMPLEXITY_SCORE_BEFORE_COMPILATION = "complexity_score_before_compilation" + COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization" + COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN = ( + "complexity_score_after_large_query_breakdown" + ) class SkipLargeQueryBreakdownCategory(Enum): diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 05488398d16..025eb57c540 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -79,6 +79,20 @@ class TelemetryField(Enum): QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" QUERY_PLAN_COMPLEXITY = "query_plan_complexity" + # temp table 
cleanup + TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup" + NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned" + NUM_TEMP_TABLES_CREATED = "num_temp_tables_created" + TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled" + TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = ( + "snowpark_temp_table_cleanup_abnormal_exception" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = ( + "temp_table_cleanup_abnormal_exception_table_name" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = ( + "temp_table_cleanup_abnormal_exception_message" + ) # These DataFrame APIs call other DataFrame APIs @@ -168,6 +182,11 @@ def wrap(*args, **kwargs): ]._session.sql_simplifier_enabled try: api_calls[0][TelemetryField.QUERY_PLAN_HEIGHT.value] = plan.plan_height + # The uuid for df._select_statement can be different from df._plan. Since plan + # can take both values, we cannot use plan.uuid. We always use df._plan.uuid + # to track the queries. + uuid = args[0]._plan.uuid + api_calls[0][CompilationStageTelemetryField.PLAN_UUID.value] = uuid api_calls[0][ TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value ] = plan.num_duplicate_nodes @@ -369,7 +388,7 @@ def send_sql_simplifier_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, - TelemetryField.SQL_SIMPLIFIER_ENABLED.value: True, + TelemetryField.SQL_SIMPLIFIER_ENABLED.value: sql_simplifier_enabled, }, } self.send(message) @@ -423,7 +442,25 @@ def send_large_query_breakdown_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: True, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: value, + }, + } + self.send(message) + + def send_query_compilation_summary_telemetry( + self, + session_id: int, + plan_uuid: str, + compilation_stage_summary: Dict[str, Any], + ) -> None: + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_STATISTICS.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid, + **compilation_stage_summary, }, } self.send(message) @@ -441,3 +478,60 @@ def send_large_query_optimization_skipped_telemetry( }, } self.send(message) + + def send_temp_table_cleanup_telemetry( + self, + session_id: str, + temp_table_cleaner_enabled: bool, + num_temp_tables_cleaned: int, + num_temp_tables_created: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANER_ENABLED.value: temp_table_cleaner_enabled, + TelemetryField.NUM_TEMP_TABLES_CLEANED.value: num_temp_tables_cleaned, + TelemetryField.NUM_TEMP_TABLES_CREATED.value: num_temp_tables_created, + }, + } + self.send(message) + + def send_temp_table_cleanup_abnormal_exception_telemetry( + self, + session_id: str, + table_name: str, + exception_message: str, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME.value: table_name, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE.value: exception_message, + }, + } + self.send(message) + + def send_large_query_breakdown_update_complexity_bounds( + self, 
session_id: int, lower_bound: int, upper_bound: int + ): + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_DATA.value: { + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: ( + lower_bound, + upper_bound, + ), + }, + }, + } + self.send(message) diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py index b9055c6fc58..4fa17498d34 100644 --- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py +++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py @@ -4,9 +4,7 @@ import logging import weakref from collections import defaultdict -from queue import Empty, Queue -from threading import Event, Thread -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable @@ -33,12 +31,6 @@ def __init__(self, session: "Session") -> None: # to its reference count for later temp table management # this dict will still be maintained even if the cleaner is stopped (`stop()` is called) self.ref_count_map: Dict[str, int] = defaultdict(int) - # unused temp table will be put into the queue for cleanup - self.queue: Queue = Queue() - # thread for removing temp tables (running DROP TABLE sql) - self.cleanup_thread: Optional[Thread] = None - # An event managing a flag that indicates whether the cleaner is started - self.stop_event = Event() def add(self, table: SnowflakeTable) -> None: self.ref_count_map[table.name] += 1 @@ -46,61 +38,60 @@ def add(self, table: SnowflakeTable) -> None: # and this table will be dropped finally _ = weakref.finalize(table, self._delete_ref_count, table.name) - def _delete_ref_count(self, name: str) -> None: + def _delete_ref_count(self, name: str) -> None: # pragma: no cover """ Decrements the reference count of a temporary table, and if the count reaches zero, puts this table in the queue for cleanup. 
""" self.ref_count_map[name] -= 1 if self.ref_count_map[name] == 0: - self.ref_count_map.pop(name) - # clean up - self.queue.put(name) + if self.session.auto_clean_up_temp_table_enabled: + self.drop_table(name) elif self.ref_count_map[name] < 0: logging.debug( f"Unexpected reference count {self.ref_count_map[name]} for table {name}" ) - def process_cleanup(self) -> None: - while not self.stop_event.is_set(): - try: - # it's non-blocking after timeout and become interruptable with stop_event - # it will raise an `Empty` exception if queue is empty after timeout, - # then we catch this exception and avoid breaking loop - table_name = self.queue.get(timeout=1) - self.drop_table(table_name) - except Empty: - continue - - def drop_table(self, name: str) -> None: + def drop_table(self, name: str) -> None: # pragma: no cover common_log_text = f"temp table {name} in session {self.session.session_id}" - logging.debug(f"Cleanup Thread: Ready to drop {common_log_text}") + logging.debug(f"Ready to drop {common_log_text}") + query_id = None try: - # TODO SNOW-1556553: Remove this workaround once multi-threading of Snowpark session is supported - with self.session._conn._conn.cursor() as cursor: - cursor.execute( - f"drop table if exists {name} /* internal query to drop unused temp table */", - _statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name}, + async_job = self.session.sql( + f"drop table if exists {name} /* internal query to drop unused temp table */", + )._internal_collect_with_tag_no_telemetry( + block=False, statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name} + ) + query_id = async_job.query_id + logging.debug(f"Dropping {common_log_text} with query id {query_id}") + except Exception as ex: # pragma: no cover + warning_message = f"Failed to drop {common_log_text}, exception: {ex}" + logging.warning(warning_message) + if query_id is None: + # If no query_id is available, it means the query haven't been accepted by gs, + # and it won't occur in our job_etl_view, send a separate telemetry for recording. + self.session._conn._telemetry_client.send_temp_table_cleanup_abnormal_exception_telemetry( + self.session.session_id, + name, + str(ex), ) - logging.debug(f"Cleanup Thread: Successfully dropped {common_log_text}") - except Exception as ex: - logging.warning( - f"Cleanup Thread: Failed to drop {common_log_text}, exception: {ex}" - ) # pragma: no cover - - def is_alive(self) -> bool: - return self.cleanup_thread is not None and self.cleanup_thread.is_alive() - - def start(self) -> None: - self.stop_event.clear() - if not self.is_alive(): - self.cleanup_thread = Thread(target=self.process_cleanup) - self.cleanup_thread.start() def stop(self) -> None: """ - The cleaner will stop immediately and leave unfinished temp tables in the queue. + Stops the cleaner (no-op) and sends the telemetry. 
""" - self.stop_event.set() - if self.is_alive(): - self.cleanup_thread.join() + self.session._conn._telemetry_client.send_temp_table_cleanup_telemetry( + self.session.session_id, + temp_table_cleaner_enabled=self.session.auto_clean_up_temp_table_enabled, + num_temp_tables_cleaned=self.num_temp_tables_cleaned, + num_temp_tables_created=self.num_temp_tables_created, + ) + + @property + def num_temp_tables_created(self) -> int: + return len(self.ref_count_map) + + @property + def num_temp_tables_cleaned(self) -> int: + # TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled + return sum(v == 0 for v in self.ref_count_map.values()) diff --git a/src/snowflake/snowpark/mock/_connection.py b/src/snowflake/snowpark/mock/_connection.py index 9e8d4d0d721..b384931cb89 100644 --- a/src/snowflake/snowpark/mock/_connection.py +++ b/src/snowflake/snowpark/mock/_connection.py @@ -6,6 +6,7 @@ import functools import json import logging +import threading import uuid from copy import copy from decimal import Decimal @@ -91,35 +92,39 @@ def __init__(self, conn: "MockServerConnection") -> None: self.table_registry = {} self.view_registry = {} self.conn = conn + self._lock = self.conn.get_lock() def is_existing_table(self, name: Union[str, Iterable[str]]) -> bool: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - qualified_name = get_fully_qualified_name( - name, current_schema, current_database - ) - return qualified_name in self.table_registry + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + qualified_name = get_fully_qualified_name( + name, current_schema, current_database + ) + return qualified_name in self.table_registry def is_existing_view(self, name: Union[str, Iterable[str]]) -> bool: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - qualified_name = get_fully_qualified_name( - name, current_schema, current_database - ) - return qualified_name in self.view_registry + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + qualified_name = get_fully_qualified_name( + name, current_schema, current_database + ) + return qualified_name in self.view_registry def read_table(self, name: Union[str, Iterable[str]]) -> TableEmulator: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - qualified_name = get_fully_qualified_name( - name, current_schema, current_database - ) - if qualified_name in self.table_registry: - return copy(self.table_registry[qualified_name]) - else: - raise SnowparkLocalTestingException( - f"Object '{name}' does not exist or not authorized." + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + qualified_name = get_fully_qualified_name( + name, current_schema, current_database ) + if qualified_name in self.table_registry: + return copy(self.table_registry[qualified_name]) + else: + raise SnowparkLocalTestingException( + f"Object '{name}' does not exist or not authorized." 
+ ) def write_table( self, @@ -128,127 +133,155 @@ def write_table( mode: SaveMode, column_names: Optional[List[str]] = None, ) -> List[Row]: - for column in table.columns: - if not table[column].sf_type.nullable and table[column].isnull().any(): - raise SnowparkLocalTestingException( - "NULL result in a non-nullable column" - ) - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - name = get_fully_qualified_name(name, current_schema, current_database) - table = copy(table) - if mode == SaveMode.APPEND: - if name in self.table_registry: - target_table = self.table_registry[name] - input_schema = table.columns.to_list() - existing_schema = target_table.columns.to_list() - - if not column_names: # append with column_order being index - if len(input_schema) != len(existing_schema): - raise SnowparkLocalTestingException( - f"Cannot append because incoming data has different schema {input_schema} than existing table {existing_schema}" - ) - # temporarily align the column names of both dataframe to be col indexes 0, 1, ... N - 1 - table.columns = range(table.shape[1]) - target_table.columns = range(target_table.shape[1]) - else: # append with column_order being name - if invalid_cols := set(input_schema) - set(existing_schema): - identifiers = "', '".join( - unquote_if_quoted(id) for id in invalid_cols - ) - raise SnowparkLocalTestingException( - f"table contains invalid identifier '{identifiers}'" - ) - invalid_non_nullable_cols = [] - for missing_col in set(existing_schema) - set(input_schema): - if target_table[missing_col].sf_type.nullable: - table[missing_col] = None - table.sf_types[missing_col] = target_table[ - missing_col - ].sf_type - else: - invalid_non_nullable_cols.append(missing_col) - if invalid_non_nullable_cols: - identifiers = "', '".join( - unquote_if_quoted(id) - for id in invalid_non_nullable_cols - ) - raise SnowparkLocalTestingException( - f"NULL result in a non-nullable column '{identifiers}'" - ) - - self.table_registry[name] = pandas.concat( - [target_table, table], ignore_index=True - ) - self.table_registry[name].columns = existing_schema - self.table_registry[name].sf_types = target_table.sf_types - else: - self.table_registry[name] = table - elif mode == SaveMode.IGNORE: - if name not in self.table_registry: - self.table_registry[name] = table - elif mode == SaveMode.OVERWRITE: - self.table_registry[name] = table - elif mode == SaveMode.ERROR_IF_EXISTS: - if name in self.table_registry: - raise SnowparkLocalTestingException(f"Table {name} already exists") - else: - self.table_registry[name] = table - elif mode == SaveMode.TRUNCATE: - if name in self.table_registry: - target_table = self.table_registry[name] - input_schema = set(table.columns.to_list()) - existing_schema = set(target_table.columns.to_list()) - # input is a subset of existing schema and all missing columns are nullable - if input_schema.issubset(existing_schema) and all( - target_table[col].sf_type.nullable - for col in set(existing_schema - input_schema) + with self._lock: + for column in table.columns: + if ( + not table[column].sf_type.nullable + and table[column].isnull().any() ): - for col in set(existing_schema - input_schema): - table[col] = ColumnEmulator( - data=[None] * table.shape[0], - sf_type=target_table[col].sf_type, - dtype=object, - ) + raise SnowparkLocalTestingException( + "NULL result in a non-nullable column" + ) + current_schema = self.conn._get_current_parameter("schema") + current_database = 
self.conn._get_current_parameter("database") + name = get_fully_qualified_name(name, current_schema, current_database) + table = copy(table) + if mode == SaveMode.APPEND: + if name in self.table_registry: + target_table = self.table_registry[name] + input_schema = table.columns.to_list() + existing_schema = target_table.columns.to_list() + + if not column_names: # append with column_order being index + if len(input_schema) != len(existing_schema): + raise SnowparkLocalTestingException( + f"Cannot append because incoming data has different schema {input_schema} than existing table {existing_schema}" + ) + # temporarily align the column names of both dataframe to be col indexes 0, 1, ... N - 1 + table.columns = range(table.shape[1]) + target_table.columns = range(target_table.shape[1]) + else: # append with column_order being name + if invalid_cols := set(input_schema) - set(existing_schema): + identifiers = "', '".join( + unquote_if_quoted(id) for id in invalid_cols + ) + raise SnowparkLocalTestingException( + f"table contains invalid identifier '{identifiers}'" + ) + invalid_non_nullable_cols = [] + for missing_col in set(existing_schema) - set(input_schema): + if target_table[missing_col].sf_type.nullable: + table[missing_col] = None + table.sf_types[missing_col] = target_table[ + missing_col + ].sf_type + else: + invalid_non_nullable_cols.append(missing_col) + if invalid_non_nullable_cols: + identifiers = "', '".join( + unquote_if_quoted(id) + for id in invalid_non_nullable_cols + ) + raise SnowparkLocalTestingException( + f"NULL result in a non-nullable column '{identifiers}'" + ) + + self.table_registry[name] = pandas.concat( + [target_table, table], ignore_index=True + ) + self.table_registry[name].columns = existing_schema + self.table_registry[name].sf_types = target_table.sf_types else: + self.table_registry[name] = table + elif mode == SaveMode.IGNORE: + if name not in self.table_registry: + self.table_registry[name] = table + elif mode == SaveMode.OVERWRITE: + self.table_registry[name] = table + elif mode == SaveMode.ERROR_IF_EXISTS: + if name in self.table_registry: raise SnowparkLocalTestingException( - f"Cannot truncate because incoming data has different schema {table.columns.to_list()} than existing table { target_table.columns.to_list()}" + f"Table {name} already exists" ) - table.sf_types_by_col_index = target_table.sf_types_by_col_index - table = table.reindex(columns=target_table.columns) - self.table_registry[name] = table - else: - raise SnowparkLocalTestingException(f"Unrecognized mode: {mode}") - return [ - Row(status=f"Table {name} successfully created.") - ] # TODO: match message + else: + self.table_registry[name] = table + elif mode == SaveMode.TRUNCATE: + if name in self.table_registry: + target_table = self.table_registry[name] + input_schema = set(table.columns.to_list()) + existing_schema = set(target_table.columns.to_list()) + # input is a subset of existing schema and all missing columns are nullable + if input_schema.issubset(existing_schema) and all( + target_table[col].sf_type.nullable + for col in set(existing_schema - input_schema) + ): + for col in set(existing_schema - input_schema): + table[col] = ColumnEmulator( + data=[None] * table.shape[0], + sf_type=target_table[col].sf_type, + dtype=object, + ) + else: + raise SnowparkLocalTestingException( + f"Cannot truncate because incoming data has different schema {table.columns.to_list()} than existing table { target_table.columns.to_list()}" + ) + table.sf_types_by_col_index = 
target_table.sf_types_by_col_index + table = table.reindex(columns=target_table.columns) + self.table_registry[name] = table + else: + raise SnowparkLocalTestingException(f"Unrecognized mode: {mode}") + return [ + Row(status=f"Table {name} successfully created.") + ] # TODO: match message def drop_table(self, name: Union[str, Iterable[str]]) -> None: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - name = get_fully_qualified_name(name, current_schema, current_database) - if name in self.table_registry: - self.table_registry.pop(name) + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + name = get_fully_qualified_name(name, current_schema, current_database) + if name in self.table_registry: + self.table_registry.pop(name) def create_or_replace_view( self, execution_plan: MockExecutionPlan, name: Union[str, Iterable[str]] ): - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - name = get_fully_qualified_name(name, current_schema, current_database) - self.view_registry[name] = execution_plan + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + name = get_fully_qualified_name(name, current_schema, current_database) + self.view_registry[name] = execution_plan def get_review(self, name: Union[str, Iterable[str]]) -> MockExecutionPlan: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - name = get_fully_qualified_name(name, current_schema, current_database) - if name in self.view_registry: - return self.view_registry[name] - raise SnowparkLocalTestingException(f"View {name} does not exist") + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + name = get_fully_qualified_name(name, current_schema, current_database) + if name in self.view_registry: + return self.view_registry[name] + raise SnowparkLocalTestingException(f"View {name} does not exist") + + def read_view_if_exists( + self, name: Union[str, Iterable[str]] + ) -> Optional[MockExecutionPlan]: + """Method to atomically read a view if it exists. Returns None if the view does not exist.""" + with self._lock: + if self.is_existing_view(name): + return self.get_review(name) + return None + + def read_table_if_exists( + self, name: Union[str, Iterable[str]] + ) -> Optional[TableEmulator]: + """Method to atomically read a table if it exists. 
Returns None if the table does not exist.""" + with self._lock: + if self.is_existing_table(name): + return self.read_table(name) + return None def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self._conn = MockedSnowflakeConnection() self._cursor = Mock() + self._lock = threading.RLock() self._lower_case_parameters = {} self.remove_query_listener = Mock() self.add_query_listener = Mock() @@ -301,7 +334,7 @@ def log_not_supported_error( warning_logger: Optional[logging.Logger] = None, ): """ - send telemetry to oob servie, can raise error or logging a warning based upon the input + send telemetry to oob service, can raise error or logging a warning based upon the input Args: external_feature_name: customer facing feature name, this information is used to raise error @@ -323,25 +356,31 @@ def log_not_supported_error( def _get_client_side_session_parameter(self, name: str, default_value: Any) -> Any: # mock implementation - return ( - self._conn._session_parameters.get(name, default_value) - if self._conn._session_parameters - else default_value - ) + with self._lock: + return ( + self._conn._session_parameters.get(name, default_value) + if self._conn._session_parameters + else default_value + ) def get_session_id(self) -> int: return 1 + def get_lock(self): + return self._lock + def close(self) -> None: - if self._conn: - self._conn.close() + with self._lock: + if self._conn: + self._conn.close() def is_closed(self) -> bool: return self._conn.is_closed() def _get_current_parameter(self, param: str, quoted: bool = True) -> Optional[str]: try: - name = getattr(self, f"_active_{param}", None) + with self._lock: + name = getattr(self, f"_active_{param}", None) if name and len(name) >= 2 and name[0] == name[-1] == '"': # it is a quoted identifier, return the original value return name diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py index edf9ffc68b3..3842f6fda34 100644 --- a/src/snowflake/snowpark/mock/_functions.py +++ b/src/snowflake/snowpark/mock/_functions.py @@ -10,6 +10,7 @@ import operator import re import string +import threading from decimal import Decimal from functools import partial, reduce from numbers import Real @@ -130,14 +131,17 @@ def __call__(self, *args, input_data=None, row_number=None, **kwargs): class MockedFunctionRegistry: _instance = None + _lock_init = threading.Lock() def __init__(self) -> None: self._registry = dict() + self._lock = threading.RLock() @classmethod def get_or_create(cls) -> "MockedFunctionRegistry": - if cls._instance is None: - cls._instance = MockedFunctionRegistry() + with cls._lock_init: + if cls._instance is None: + cls._instance = MockedFunctionRegistry() return cls._instance def get_function( @@ -151,10 +155,11 @@ def get_function( distinct = func.is_distinct func_name = func_name.lower() - if func_name not in self._registry: - return None + with self._lock: + if func_name not in self._registry: + return None - function = self._registry[func_name] + function = self._registry[func_name] return function.distinct if distinct else function @@ -169,7 +174,8 @@ def register( snowpark_func if isinstance(snowpark_func, str) else snowpark_func.__name__ ) mocked_function = MockedFunction(name, func_implementation, *args, **kwargs) - self._registry[name] = mocked_function + with self._lock: + self._registry[name] = mocked_function return mocked_function def unregister( @@ -180,8 +186,9 @@ def unregister( snowpark_func if isinstance(snowpark_func, str) else snowpark_func.__name__ ) - if 
name in self._registry: - del self._registry[name] + with self._lock: + if name in self._registry: + del self._registry[name] class LocalTimezone: diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py index 11e54802eea..aa86b2598d6 100644 --- a/src/snowflake/snowpark/mock/_plan.py +++ b/src/snowflake/snowpark/mock/_plan.py @@ -357,18 +357,21 @@ def handle_function_expression( current_row=None, ): func = MockedFunctionRegistry.get_or_create().get_function(exp) + connection_lock = analyzer.session._conn.get_lock() if func is None: - current_schema = analyzer.session.get_current_schema() - current_database = analyzer.session.get_current_database() + with connection_lock: + current_schema = analyzer.session.get_current_schema() + current_database = analyzer.session.get_current_database() udf_name = get_fully_qualified_name(exp.name, current_schema, current_database) # If udf name in the registry then this is a udf, not an actual function - if udf_name in analyzer.session.udf._registry: - exp.udf_name = udf_name - return handle_udf_expression( - exp, input_data, analyzer, expr_to_alias, current_row - ) + with connection_lock: + if udf_name in analyzer.session.udf._registry: + exp.udf_name = udf_name + return handle_udf_expression( + exp, input_data, analyzer, expr_to_alias, current_row + ) if exp.api_call_source == "functions.call_udf": raise SnowparkLocalTestingException( @@ -463,9 +466,12 @@ def handle_udf_expression( ): udf_registry = analyzer.session.udf udf_name = exp.udf_name - udf = udf_registry.get_udf(udf_name) + connection_lock = analyzer.session._conn.get_lock() + with connection_lock: + udf = udf_registry.get_udf(udf_name) + udf_imports = udf_registry.get_udf_imports(udf_name) - with ImportContext(udf_registry.get_udf_imports(udf_name)): + with ImportContext(udf_imports): # Resolve handler callable if type(udf.func) is tuple: module_name, handler_name = udf.func @@ -556,6 +562,7 @@ def execute_mock_plan( analyzer = plan.analyzer entity_registry = analyzer.session._conn.entity_registry + connection_lock = analyzer.session._conn.get_lock() if isinstance(source_plan, SnowflakeValues): table = TableEmulator( @@ -728,18 +735,20 @@ def execute_mock_plan( return res_df if isinstance(source_plan, MockSelectableEntity): entity_name = source_plan.entity.name - if entity_registry.is_existing_table(entity_name): - return entity_registry.read_table(entity_name) - elif entity_registry.is_existing_view(entity_name): - execution_plan = entity_registry.get_review(entity_name) + table = entity_registry.read_table_if_exists(entity_name) + if table is not None: + return table + + execution_plan = entity_registry.read_view_if_exists(entity_name) + if execution_plan is not None: res_df = execute_mock_plan(execution_plan, expr_to_alias) return res_df - else: - db_schme_table = parse_table_name(entity_name) - table = ".".join([part.strip("\"'") for part in db_schme_table[:3]]) - raise SnowparkLocalTestingException( - f"Object '{table}' does not exist or not authorized." - ) + + db_schema_table = parse_table_name(entity_name) + table = ".".join([part.strip("\"'") for part in db_schema_table[:3]]) + raise SnowparkLocalTestingException( + f"Object '{table}' does not exist or not authorized." 
+ ) if isinstance(source_plan, Aggregate): child_rf = execute_mock_plan(source_plan.child, expr_to_alias) if ( @@ -1111,28 +1120,30 @@ def outer_join(base_df): ) if isinstance(source_plan, SnowflakeTable): entity_name = source_plan.name - if entity_registry.is_existing_table(entity_name): - return entity_registry.read_table(entity_name) - elif entity_registry.is_existing_view(entity_name): - execution_plan = entity_registry.get_review(entity_name) + table = entity_registry.read_table_if_exists(entity_name) + if table is not None: + return table + + execution_plan = entity_registry.read_view_if_exists(entity_name) + if execution_plan is not None: res_df = execute_mock_plan(execution_plan, expr_to_alias) return res_df - else: - obj_name_tuple = parse_table_name(entity_name) - obj_name = obj_name_tuple[-1] - obj_schema = ( - obj_name_tuple[-2] - if len(obj_name_tuple) > 1 - else analyzer.session.get_current_schema() - ) - obj_database = ( - obj_name_tuple[-3] - if len(obj_name_tuple) > 2 - else analyzer.session.get_current_database() - ) - raise SnowparkLocalTestingException( - f"Object '{obj_database[1:-1]}.{obj_schema[1:-1]}.{obj_name[1:-1]}' does not exist or not authorized." - ) + + obj_name_tuple = parse_table_name(entity_name) + obj_name = obj_name_tuple[-1] + obj_schema = ( + obj_name_tuple[-2] + if len(obj_name_tuple) > 1 + else analyzer.session.get_current_schema() + ) + obj_database = ( + obj_name_tuple[-3] + if len(obj_name_tuple) > 2 + else analyzer.session.get_current_database() + ) + raise SnowparkLocalTestingException( + f"Object '{obj_database[1:-1]}.{obj_schema[1:-1]}.{obj_name[1:-1]}' does not exist or not authorized." + ) if isinstance(source_plan, Sample): res_df = execute_mock_plan(source_plan.child, expr_to_alias) @@ -1159,272 +1170,283 @@ def outer_join(base_df): return from_df if isinstance(source_plan, TableUpdate): - target = entity_registry.read_table(source_plan.table_name) - ROW_ID = "row_id_" + generate_random_alphanumeric() - target.insert(0, ROW_ID, range(len(target))) + # since we are modifying the table, we need to ensure that no other thread + # reads the table until it is updated + with connection_lock: + target = entity_registry.read_table(source_plan.table_name) + ROW_ID = "row_id_" + generate_random_alphanumeric() + target.insert(0, ROW_ID, range(len(target))) + + if source_plan.source_data: + # Calculate cartesian product + source = execute_mock_plan(source_plan.source_data, expr_to_alias) + cartesian_product = target.merge(source, on=None, how="cross") + cartesian_product.sf_types.update(target.sf_types) + cartesian_product.sf_types.update(source.sf_types) + intermediate = cartesian_product + else: + intermediate = target - if source_plan.source_data: - # Calculate cartesian product - source = execute_mock_plan(source_plan.source_data, expr_to_alias) - cartesian_product = target.merge(source, on=None, how="cross") - cartesian_product.sf_types.update(target.sf_types) - cartesian_product.sf_types.update(source.sf_types) - intermediate = cartesian_product - else: - intermediate = target + if source_plan.condition: + # Select rows to be updated based on condition + condition = calculate_expression( + source_plan.condition, intermediate, analyzer, expr_to_alias + ).fillna(value=False) - if source_plan.condition: - # Select rows to be updated based on condition - condition = calculate_expression( - source_plan.condition, intermediate, analyzer, expr_to_alias - ).fillna(value=False) - - matched = target.apply(tuple, 1).isin( - 
intermediate[condition][target.columns].apply(tuple, 1) + matched = target.apply(tuple, 1).isin( + intermediate[condition][target.columns].apply(tuple, 1) + ) + matched.sf_type = ColumnType(BooleanType(), True) + matched_rows = target[matched] + intermediate = intermediate[condition] + else: + matched_rows = target + + # Calculate multi_join + matched_count = intermediate[target.columns].value_counts(dropna=False)[ + matched_rows.apply(tuple, 1) + ] + multi_joins = matched_count.where(lambda x: x > 1).count() + + # Select rows that match the condition to be updated + rows_to_update = intermediate.drop_duplicates( + subset=matched_rows.columns, keep="first" + ).reset_index( # ERROR_ON_NONDETERMINISTIC_UPDATE is by default False, pick one row to update + drop=True ) - matched.sf_type = ColumnType(BooleanType(), True) - matched_rows = target[matched] - intermediate = intermediate[condition] - else: - matched_rows = target + rows_to_update.sf_types = intermediate.sf_types + + # Update rows in place + for attr, new_expr in source_plan.assignments.items(): + column_name = analyzer.analyze(attr, expr_to_alias) + target_index = target.loc[rows_to_update[ROW_ID]].index + new_val = calculate_expression( + new_expr, rows_to_update, analyzer, expr_to_alias + ) + new_val.index = target_index + target.loc[rows_to_update[ROW_ID], column_name] = new_val - # Calculate multi_join - matched_count = intermediate[target.columns].value_counts(dropna=False)[ - matched_rows.apply(tuple, 1) - ] - multi_joins = matched_count.where(lambda x: x > 1).count() + # Delete row_id + target = target.drop(ROW_ID, axis=1) - # Select rows that match the condition to be updated - rows_to_update = intermediate.drop_duplicates( - subset=matched_rows.columns, keep="first" - ).reset_index( # ERROR_ON_NONDETERMINISTIC_UPDATE is by default False, pick one row to update - drop=True - ) - rows_to_update.sf_types = intermediate.sf_types - - # Update rows in place - for attr, new_expr in source_plan.assignments.items(): - column_name = analyzer.analyze(attr, expr_to_alias) - target_index = target.loc[rows_to_update[ROW_ID]].index - new_val = calculate_expression( - new_expr, rows_to_update, analyzer, expr_to_alias + # Write result back to table + entity_registry.write_table( + source_plan.table_name, target, SaveMode.OVERWRITE ) - new_val.index = target_index - target.loc[rows_to_update[ROW_ID], column_name] = new_val - - # Delete row_id - target = target.drop(ROW_ID, axis=1) - - # Write result back to table - entity_registry.write_table(source_plan.table_name, target, SaveMode.OVERWRITE) return [Row(len(rows_to_update), multi_joins)] elif isinstance(source_plan, TableDelete): - target = entity_registry.read_table(source_plan.table_name) + # since we are modifying the table, we need to ensure that no other thread + # reads the table until it is updated + with connection_lock: + target = entity_registry.read_table(source_plan.table_name) + + if source_plan.source_data: + # Calculate cartesian product + source = execute_mock_plan(source_plan.source_data, expr_to_alias) + cartesian_product = target.merge(source, on=None, how="cross") + cartesian_product.sf_types.update(target.sf_types) + cartesian_product.sf_types.update(source.sf_types) + intermediate = cartesian_product + else: + intermediate = target + + # Select rows to keep based on condition + if source_plan.condition: + condition = calculate_expression( + source_plan.condition, intermediate, analyzer, expr_to_alias + ).fillna(value=False) + intermediate = intermediate[condition] + 
matched = target.apply(tuple, 1).isin( + intermediate[target.columns].apply(tuple, 1) + ) + matched.sf_type = ColumnType(BooleanType(), True) + rows_to_keep = target[~matched] + else: + rows_to_keep = target.head(0) - if source_plan.source_data: + # Write rows to keep to table registry + entity_registry.write_table( + source_plan.table_name, rows_to_keep, SaveMode.OVERWRITE + ) + return [Row(len(target) - len(rows_to_keep))] + elif isinstance(source_plan, TableMerge): + # since we are modifying the table, we need to ensure that no other thread + # reads the table until it is updated + with connection_lock: + target = entity_registry.read_table(source_plan.table_name) + ROW_ID = "row_id_" + generate_random_alphanumeric() + SOURCE_ROW_ID = "source_row_id_" + generate_random_alphanumeric() # Calculate cartesian product - source = execute_mock_plan(source_plan.source_data, expr_to_alias) + source = execute_mock_plan(source_plan.source, expr_to_alias) + + # Insert row_id and source row_id + target.insert(0, ROW_ID, range(len(target))) + source.insert(0, SOURCE_ROW_ID, range(len(source))) + cartesian_product = target.merge(source, on=None, how="cross") cartesian_product.sf_types.update(target.sf_types) cartesian_product.sf_types.update(source.sf_types) - intermediate = cartesian_product - else: - intermediate = target - - # Select rows to keep based on condition - if source_plan.condition: - condition = calculate_expression( - source_plan.condition, intermediate, analyzer, expr_to_alias - ).fillna(value=False) - intermediate = intermediate[condition] - matched = target.apply(tuple, 1).isin( - intermediate[target.columns].apply(tuple, 1) + join_condition = calculate_expression( + source_plan.join_expr, cartesian_product, analyzer, expr_to_alias ) - matched.sf_type = ColumnType(BooleanType(), True) - rows_to_keep = target[~matched] - else: - rows_to_keep = target.head(0) + join_result = cartesian_product[join_condition] + join_result.sf_types = cartesian_product.sf_types + + # TODO [GA]: # ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if + # (1) A target row is selected to be updated with multiple values OR + # (2) A target row is selected to be both updated and deleted + + inserted_rows = [] + insert_clause_specified = ( + update_clause_specified + ) = delete_clause_specified = False + inserted_row_idx = set() # source_row_id + deleted_row_idx = set() + updated_row_idx = set() + for clause in source_plan.clauses: + if isinstance(clause, UpdateMergeExpression): + update_clause_specified = True + # Select rows to update + if clause.condition: + condition = calculate_expression( + clause.condition, join_result, analyzer, expr_to_alias + ).fillna(value=False) + rows_to_update = join_result[condition] + else: + rows_to_update = join_result - # Write rows to keep to table registry - entity_registry.write_table( - source_plan.table_name, rows_to_keep, SaveMode.OVERWRITE - ) - return [Row(len(target) - len(rows_to_keep))] - elif isinstance(source_plan, TableMerge): - target = entity_registry.read_table(source_plan.table_name) - ROW_ID = "row_id_" + generate_random_alphanumeric() - SOURCE_ROW_ID = "source_row_id_" + generate_random_alphanumeric() - # Calculate cartesian product - source = execute_mock_plan(source_plan.source, expr_to_alias) - - # Insert row_id and source row_id - target.insert(0, ROW_ID, range(len(target))) - source.insert(0, SOURCE_ROW_ID, range(len(source))) - - cartesian_product = target.merge(source, on=None, how="cross") - 
cartesian_product.sf_types.update(target.sf_types) - cartesian_product.sf_types.update(source.sf_types) - join_condition = calculate_expression( - source_plan.join_expr, cartesian_product, analyzer, expr_to_alias - ) - join_result = cartesian_product[join_condition] - join_result.sf_types = cartesian_product.sf_types - - # TODO [GA]: # ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if - # (1) A target row is selected to be updated with multiple values OR - # (2) A target row is selected to be both updated and deleted - - inserted_rows = [] - insert_clause_specified = ( - update_clause_specified - ) = delete_clause_specified = False - inserted_row_idx = set() # source_row_id - deleted_row_idx = set() - updated_row_idx = set() - for clause in source_plan.clauses: - if isinstance(clause, UpdateMergeExpression): - update_clause_specified = True - # Select rows to update - if clause.condition: - condition = calculate_expression( - clause.condition, join_result, analyzer, expr_to_alias - ).fillna(value=False) - rows_to_update = join_result[condition] - else: - rows_to_update = join_result + rows_to_update = rows_to_update[ + ~rows_to_update[ROW_ID] + .isin(updated_row_idx.union(deleted_row_idx)) + .values + ] - rows_to_update = rows_to_update[ - ~rows_to_update[ROW_ID] - .isin(updated_row_idx.union(deleted_row_idx)) - .values - ] + # Update rows in place + for attr, new_expr in clause.assignments.items(): + column_name = analyzer.analyze(attr, expr_to_alias) + target_index = target.loc[rows_to_update[ROW_ID]].index + new_val = calculate_expression( + new_expr, rows_to_update, analyzer, expr_to_alias + ) + new_val.index = target_index + target.loc[rows_to_update[ROW_ID], column_name] = new_val + + # Update updated row id set + for _, row in rows_to_update.iterrows(): + updated_row_idx.add(row[ROW_ID]) + + elif isinstance(clause, DeleteMergeExpression): + delete_clause_specified = True + # Select rows to delete + if clause.condition: + condition = calculate_expression( + clause.condition, join_result, analyzer, expr_to_alias + ).fillna(value=False) + intermediate = join_result[condition] + else: + intermediate = join_result - # Update rows in place - for attr, new_expr in clause.assignments.items(): - column_name = analyzer.analyze(attr, expr_to_alias) - target_index = target.loc[rows_to_update[ROW_ID]].index - new_val = calculate_expression( - new_expr, rows_to_update, analyzer, expr_to_alias + matched = target.apply(tuple, 1).isin( + intermediate[target.columns].apply(tuple, 1) ) - new_val.index = target_index - target.loc[rows_to_update[ROW_ID], column_name] = new_val - - # Update updated row id set - for _, row in rows_to_update.iterrows(): - updated_row_idx.add(row[ROW_ID]) - - elif isinstance(clause, DeleteMergeExpression): - delete_clause_specified = True - # Select rows to delete - if clause.condition: - condition = calculate_expression( - clause.condition, join_result, analyzer, expr_to_alias - ).fillna(value=False) - intermediate = join_result[condition] - else: - intermediate = join_result + matched.sf_type = ColumnType(BooleanType(), True) - matched = target.apply(tuple, 1).isin( - intermediate[target.columns].apply(tuple, 1) - ) - matched.sf_type = ColumnType(BooleanType(), True) + # Update deleted row id set + for _, row in target[matched].iterrows(): + deleted_row_idx.add(row[ROW_ID]) - # Update deleted row id set - for _, row in target[matched].iterrows(): - deleted_row_idx.add(row[ROW_ID]) + # Delete rows in place + target = target[~matched] - # Delete rows in 
place - target = target[~matched] + elif isinstance(clause, InsertMergeExpression): + insert_clause_specified = True + # calculate unmatched rows in the source + matched = source.apply(tuple, 1).isin( + join_result[source.columns].apply(tuple, 1) + ) + matched.sf_type = ColumnType(BooleanType(), True) + unmatched_rows_in_source = source[~matched] + + # select unmatched rows that qualify the condition + if clause.condition: + condition = calculate_expression( + clause.condition, + unmatched_rows_in_source, + analyzer, + expr_to_alias, + ).fillna(value=False) + unmatched_rows_in_source = unmatched_rows_in_source[condition] + + # filter out the unmatched rows that have been inserted in previous clauses + unmatched_rows_in_source = unmatched_rows_in_source[ + ~unmatched_rows_in_source[SOURCE_ROW_ID] + .isin(inserted_row_idx) + .values + ] - elif isinstance(clause, InsertMergeExpression): - insert_clause_specified = True - # calculate unmatched rows in the source - matched = source.apply(tuple, 1).isin( - join_result[source.columns].apply(tuple, 1) - ) - matched.sf_type = ColumnType(BooleanType(), True) - unmatched_rows_in_source = source[~matched] + # update inserted row idx set + for _, row in unmatched_rows_in_source.iterrows(): + inserted_row_idx.add(row[SOURCE_ROW_ID]) - # select unmatched rows that qualify the condition - if clause.condition: - condition = calculate_expression( - clause.condition, - unmatched_rows_in_source, - analyzer, - expr_to_alias, - ).fillna(value=False) - unmatched_rows_in_source = unmatched_rows_in_source[condition] - - # filter out the unmatched rows that have been inserted in previous clauses - unmatched_rows_in_source = unmatched_rows_in_source[ - ~unmatched_rows_in_source[SOURCE_ROW_ID] - .isin(inserted_row_idx) - .values - ] + # Calculate rows to insert + rows_to_insert = TableEmulator( + [], columns=target.drop(ROW_ID, axis=1).columns, dtype=object + ) + rows_to_insert.sf_types = target.sf_types + if clause.keys: + # Keep track of specified columns + inserted_columns = set() + for k, v in zip(clause.keys, clause.values): + column_name = analyzer.analyze(k, expr_to_alias) + if column_name not in rows_to_insert.columns: + raise SnowparkLocalTestingException( + f"invalid identifier '{column_name}'" + ) + inserted_columns.add(column_name) + new_val = calculate_expression( + v, unmatched_rows_in_source, analyzer, expr_to_alias + ) + # pandas could do implicit type conversion, e.g. 
from datetime to timestamp + # reconstructing ColumnEmulator helps preserve the original date type + rows_to_insert[column_name] = ColumnEmulator( + new_val.values, + dtype=object, + sf_type=rows_to_insert[column_name].sf_type, + ) - # update inserted row idx set - for _, row in unmatched_rows_in_source.iterrows(): - inserted_row_idx.add(row[SOURCE_ROW_ID]) + # For unspecified columns, use None as default value + for unspecified_col in set(rows_to_insert.columns).difference( + inserted_columns + ): + rows_to_insert[unspecified_col].replace( + np.nan, None, inplace=True + ) - # Calculate rows to insert - rows_to_insert = TableEmulator( - [], columns=target.drop(ROW_ID, axis=1).columns, dtype=object - ) - rows_to_insert.sf_types = target.sf_types - if clause.keys: - # Keep track of specified columns - inserted_columns = set() - for k, v in zip(clause.keys, clause.values): - column_name = analyzer.analyze(k, expr_to_alias) - if column_name not in rows_to_insert.columns: + else: + if len(clause.values) != len(rows_to_insert.columns): raise SnowparkLocalTestingException( - f"invalid identifier '{column_name}'" + f"Insert value list does not match column list expecting {len(rows_to_insert.columns)} but got {len(clause.values)}" ) - inserted_columns.add(column_name) - new_val = calculate_expression( - v, unmatched_rows_in_source, analyzer, expr_to_alias - ) - # pandas could do implicit type conversion, e.g. from datetime to timestamp - # reconstructing ColumnEmulator helps preserve the original date type - rows_to_insert[column_name] = ColumnEmulator( - new_val.values, - dtype=object, - sf_type=rows_to_insert[column_name].sf_type, - ) - - # For unspecified columns, use None as default value - for unspecified_col in set(rows_to_insert.columns).difference( - inserted_columns - ): - rows_to_insert[unspecified_col].replace( - np.nan, None, inplace=True - ) - - else: - if len(clause.values) != len(rows_to_insert.columns): - raise SnowparkLocalTestingException( - f"Insert value list does not match column list expecting {len(rows_to_insert.columns)} but got {len(clause.values)}" - ) - for col, v in zip(rows_to_insert.columns, clause.values): - new_val = calculate_expression( - v, unmatched_rows_in_source, analyzer, expr_to_alias - ) - rows_to_insert[col] = new_val + for col, v in zip(rows_to_insert.columns, clause.values): + new_val = calculate_expression( + v, unmatched_rows_in_source, analyzer, expr_to_alias + ) + rows_to_insert[col] = new_val - inserted_rows.append(rows_to_insert) + inserted_rows.append(rows_to_insert) - # Remove inserted ROW ID column - target = target.drop(ROW_ID, axis=1) + # Remove inserted ROW ID column + target = target.drop(ROW_ID, axis=1) - # Process inserted rows - if inserted_rows: - res = pd.concat([target] + inserted_rows) - res.sf_types = target.sf_types - else: - res = target + # Process inserted rows + if inserted_rows: + res = pd.concat([target] + inserted_rows) + res.sf_types = target.sf_types + else: + res = target - # Write the result back to table - entity_registry.write_table(source_plan.table_name, res, SaveMode.OVERWRITE) + # Write the result back to table + entity_registry.write_table(source_plan.table_name, res, SaveMode.OVERWRITE) # Generate metadata result res = [] diff --git a/src/snowflake/snowpark/mock/_stage_registry.py b/src/snowflake/snowpark/mock/_stage_registry.py index 7ed55d1cdc6..d4100606821 100644 --- a/src/snowflake/snowpark/mock/_stage_registry.py +++ b/src/snowflake/snowpark/mock/_stage_registry.py @@ -647,30 +647,34 @@ def __init__(self, 
conn: "MockServerConnection") -> None: self._root_dir = tempfile.TemporaryDirectory() self._stage_registry = {} self._conn = conn + self._lock = conn.get_lock() def create_or_replace_stage(self, stage_name): - self._stage_registry[stage_name] = StageEntity( - self._root_dir.name, stage_name, self._conn - ) + with self._lock: + self._stage_registry[stage_name] = StageEntity( + self._root_dir.name, stage_name, self._conn + ) def __getitem__(self, stage_name: str): # the assumption here is that stage always exists - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - return self._stage_registry[stage_name] + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + return self._stage_registry[stage_name] def put( self, local_file_name: str, stage_location: str, overwrite: bool = False ) -> TableEmulator: stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location) # the assumption here is that stage always exists - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - return self._stage_registry[stage_name].put_file( - local_file_name=local_file_name, - stage_prefix=stage_prefix, - overwrite=overwrite, - ) + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + return self._stage_registry[stage_name].put_file( + local_file_name=local_file_name, + stage_prefix=stage_prefix, + overwrite=overwrite, + ) def upload_stream( self, @@ -681,14 +685,15 @@ def upload_stream( ) -> Dict: stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location) # the assumption here is that stage always exists - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - return self._stage_registry[stage_name].upload_stream( - input_stream=input_stream, - stage_prefix=stage_prefix, - file_name=file_name, - overwrite=overwrite, - ) + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + return self._stage_registry[stage_name].upload_stream( + input_stream=input_stream, + stage_prefix=stage_prefix, + file_name=file_name, + overwrite=overwrite, + ) def get( self, @@ -701,14 +706,15 @@ def get( f"Invalid stage {stage_location}, stage name should start with character '@'" ) stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location) - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - - return self._stage_registry[stage_name].get_file( - stage_location=stage_prefix, - target_directory=target_directory, - options=options, - ) + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + + return self._stage_registry[stage_name].get_file( + stage_location=stage_prefix, + target_directory=target_directory, + options=options, + ) def read_file( self, @@ -723,13 +729,14 @@ def read_file( f"Invalid stage {stage_location}, stage name should start with character '@'" ) stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location) - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - - return self._stage_registry[stage_name].read_file( - stage_location=stage_prefix, - format=format, - schema=schema, - analyzer=analyzer, - options=options, - ) + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + + return self._stage_registry[stage_name].read_file( + 
stage_location=stage_prefix, + format=format, + schema=schema, + analyzer=analyzer, + options=options, + ) diff --git a/src/snowflake/snowpark/mock/_stored_procedure.py b/src/snowflake/snowpark/mock/_stored_procedure.py index d93500da2e8..14abec358c2 100644 --- a/src/snowflake/snowpark/mock/_stored_procedure.py +++ b/src/snowflake/snowpark/mock/_stored_procedure.py @@ -154,9 +154,11 @@ def __init__(self, *args, **kwargs) -> None: ) # maps name to either the callable or a pair of str (module_name, callable_name) self._sproc_level_imports = dict() # maps name to a set of file paths self._session_level_imports = set() + self._lock = self._session._conn.get_lock() def _clear_session_imports(self): - self._session_level_imports.clear() + with self._lock: + self._session_level_imports.clear() def _import_file( self, @@ -172,16 +174,17 @@ def _import_file( imports specified. """ - absolute_module_path, module_name = extract_import_dir_and_module_name( - file_path, self._session._conn.stage_registry, import_path - ) + with self._lock: + absolute_module_path, module_name = extract_import_dir_and_module_name( + file_path, self._session._conn.stage_registry, import_path + ) - if sproc_name: - self._sproc_level_imports[sproc_name].add(absolute_module_path) - else: - self._session_level_imports.add(absolute_module_path) + if sproc_name: + self._sproc_level_imports[sproc_name].add(absolute_module_path) + else: + self._session_level_imports.add(absolute_module_path) - return module_name + return module_name def _do_register_sp( self, @@ -224,90 +227,96 @@ def _do_register_sp( error_message="Registering anonymous sproc is not currently supported.", raise_error=NotImplementedError, ) - ( - sproc_name, - is_pandas_udf, - is_dataframe_input, - return_type, - input_types, - opt_arg_defaults, - ) = process_registration_inputs( - self._session, - TempObjectType.PROCEDURE, - func, - return_type, - input_types, - sp_name, - anonymous, - ) - current_schema = self._session.get_current_schema() - current_database = self._session.get_current_database() - sproc_name = get_fully_qualified_name( - sproc_name, current_schema, current_database - ) - - check_python_runtime_version(self._session._runtime_version_from_requirement) - - if replace and if_not_exists: - raise ValueError("options replace and if_not_exists are incompatible") + with self._lock: + ( + sproc_name, + is_pandas_udf, + is_dataframe_input, + return_type, + input_types, + opt_arg_defaults, + ) = process_registration_inputs( + self._session, + TempObjectType.PROCEDURE, + func, + return_type, + input_types, + sp_name, + anonymous, + ) - if sproc_name in self._registry and if_not_exists: - return self._registry[sproc_name] + current_schema = self._session.get_current_schema() + current_database = self._session.get_current_database() + sproc_name = get_fully_qualified_name( + sproc_name, current_schema, current_database + ) - if sproc_name in self._registry and not replace: - raise SnowparkLocalTestingException( - f"002002 (42710): SQL compilation error: \nObject '{sproc_name}' already exists.", - error_code="1304", + check_python_runtime_version( + self._session._runtime_version_from_requirement ) - if is_pandas_udf: - raise TypeError("pandas stored procedure is not supported") + if replace and if_not_exists: + raise ValueError("options replace and if_not_exists are incompatible") - if packages: - pass # NO-OP + if sproc_name in self._registry and if_not_exists: + return self._registry[sproc_name] - if imports is not None or type(func) is tuple: - 
self._sproc_level_imports[sproc_name] = set() + if sproc_name in self._registry and not replace: + raise SnowparkLocalTestingException( + f"002002 (42710): SQL compilation error: \nObject '{sproc_name}' already exists.", + error_code="1304", + ) - if imports is not None: - for _import in imports: - if isinstance(_import, str): - self._import_file(_import, sproc_name=sproc_name) - elif isinstance(_import, tuple) and all( - isinstance(item, str) for item in _import - ): - local_path, import_path = _import - self._import_file(local_path, import_path, sproc_name=sproc_name) - else: - raise TypeError( - "stored-proc-level import can only be a file path (str) or a tuple of the file path (str) and the import path (str)" - ) + if is_pandas_udf: + raise TypeError("pandas stored procedure is not supported") - if type(func) is tuple: # register from file - if sproc_name not in self._sproc_level_imports: - self._sproc_level_imports[sproc_name] = set() - module_name = self._import_file(func[0], sproc_name=sproc_name) - func = (module_name, func[1]) + if packages: + pass # NO-OP - if sproc_name in self._sproc_level_imports: - sproc_imports = self._sproc_level_imports[sproc_name] - else: - sproc_imports = copy(self._session_level_imports) + if imports is not None or type(func) is tuple: + self._sproc_level_imports[sproc_name] = set() - sproc = MockStoredProcedure( - func, - return_type, - input_types, - sproc_name, - sproc_imports, - execute_as=execute_as, - strict=strict, - ) + if imports is not None: + for _import in imports: + if isinstance(_import, str): + self._import_file(_import, sproc_name=sproc_name) + elif isinstance(_import, tuple) and all( + isinstance(item, str) for item in _import + ): + local_path, import_path = _import + self._import_file( + local_path, import_path, sproc_name=sproc_name + ) + else: + raise TypeError( + "stored-proc-level import can only be a file path (str) or a tuple of the file path (str) and the import path (str)" + ) + + if type(func) is tuple: # register from file + if sproc_name not in self._sproc_level_imports: + self._sproc_level_imports[sproc_name] = set() + module_name = self._import_file(func[0], sproc_name=sproc_name) + func = (module_name, func[1]) + + if sproc_name in self._sproc_level_imports: + sproc_imports = self._sproc_level_imports[sproc_name] + else: + sproc_imports = copy(self._session_level_imports) + + sproc = MockStoredProcedure( + func, + return_type, + input_types, + sproc_name, + sproc_imports, + execute_as=execute_as, + strict=strict, + ) - self._registry[sproc_name] = sproc + self._registry[sproc_name] = sproc - return sproc + return sproc def call( self, @@ -316,17 +325,18 @@ def call( session: Optional["snowflake.snowpark.session.Session"] = None, statement_params: Optional[Dict[str, str]] = None, ): - current_schema = self._session.get_current_schema() - current_database = self._session.get_current_database() - sproc_name = get_fully_qualified_name( - sproc_name, current_schema, current_database - ) - - if sproc_name not in self._registry: - raise SnowparkLocalTestingException( - f"Unknown function {sproc_name}. Stored procedure by that name does not exist." 
+ with self._lock: + current_schema = self._session.get_current_schema() + current_database = self._session.get_current_database() + sproc_name = get_fully_qualified_name( + sproc_name, current_schema, current_database ) - return self._registry[sproc_name]( - *args, session=session, statement_params=statement_params - ) + if sproc_name not in self._registry: + raise SnowparkLocalTestingException( + f"Unknown function {sproc_name}. Stored procedure by that name does not exist." + ) + + sproc = self._registry[sproc_name] + + return sproc(*args, session=session, statement_params=statement_params) diff --git a/src/snowflake/snowpark/mock/_telemetry.py b/src/snowflake/snowpark/mock/_telemetry.py index 857291b47fd..6e4273aa7ff 100644 --- a/src/snowflake/snowpark/mock/_telemetry.py +++ b/src/snowflake/snowpark/mock/_telemetry.py @@ -5,6 +5,7 @@ import json import logging import os +import threading import uuid from datetime import datetime from enum import Enum @@ -92,6 +93,7 @@ def __init__(self) -> None: ) self._deployment_url = self.PROD self._enable = True + self._lock = threading.RLock() def _upload_payload(self, payload) -> None: if not REQUESTS_AVAILABLE: @@ -136,12 +138,25 @@ def add(self, event) -> None: if not self.enabled: return - self.queue.put(event) - if self.queue.qsize() > self.batch_size: - payload = self.export_queue_to_string() - if payload is None: - return - self._upload_payload(payload) + with self._lock: + self.queue.put(event) + if self.queue.qsize() > self.batch_size: + payload = self.export_queue_to_string() + if payload is None: + return + self._upload_payload(payload) + + def flush(self) -> None: + """Flushes all telemetry events in the queue and submit them to the back-end.""" + if not self.enabled: + return + + with self._lock: + if not self.queue.empty(): + payload = self.export_queue_to_string() + if payload is None: + return + self._upload_payload(payload) @property def enabled(self) -> bool: @@ -158,8 +173,9 @@ def disable(self) -> None: def export_queue_to_string(self): logs = list() - while not self.queue.empty(): - logs.append(self.queue.get()) + with self._lock: + while not self.queue.empty(): + logs.append(self.queue.get()) # We may get an exception trying to serialize a python object to JSON try: payload = json.dumps(logs) diff --git a/src/snowflake/snowpark/mock/_udf.py b/src/snowflake/snowpark/mock/_udf.py index 7cedf0de660..a7a17d9a030 100644 --- a/src/snowflake/snowpark/mock/_udf.py +++ b/src/snowflake/snowpark/mock/_udf.py @@ -38,9 +38,11 @@ def __init__(self, *args, **kwargs) -> None: dict() ) # maps udf name to either the callable or a pair of str (module_name, callable_name) self._session_level_imports = set() + self._lock = self._session._conn.get_lock() def _clear_session_imports(self): - self._session_level_imports.clear() + with self._lock: + self._session_level_imports.clear() def _import_file( self, @@ -54,29 +56,32 @@ def _import_file( When udf_name is not None, the import is added to the UDF associated with the name; Otherwise, it is a session level import and will be used if no UDF-level imports are specified. 
""" - absolute_module_path, module_name = extract_import_dir_and_module_name( - file_path, self._session._conn.stage_registry, import_path - ) - if udf_name: - self._registry[udf_name].add_import(absolute_module_path) - else: - self._session_level_imports.add(absolute_module_path) + with self._lock: + absolute_module_path, module_name = extract_import_dir_and_module_name( + file_path, self._session._conn.stage_registry, import_path + ) + if udf_name: + self._registry[udf_name].add_import(absolute_module_path) + else: + self._session_level_imports.add(absolute_module_path) - return module_name + return module_name def get_udf(self, udf_name: str) -> MockUserDefinedFunction: - if udf_name not in self._registry: - raise SnowparkLocalTestingException(f"udf {udf_name} does not exist.") - return self._registry[udf_name] + with self._lock: + if udf_name not in self._registry: + raise SnowparkLocalTestingException(f"udf {udf_name} does not exist.") + return self._registry[udf_name] def get_udf_imports(self, udf_name: str) -> Set[str]: - udf = self._registry.get(udf_name) - if not udf: - return set() - elif udf.use_session_imports: - return self._session_level_imports - else: - return udf._imports + with self._lock: + udf = self._registry.get(udf_name) + if not udf: + return set() + elif udf.use_session_imports: + return self._session_level_imports + else: + return udf._imports def _do_register_udf( self, @@ -113,73 +118,81 @@ def _do_register_udf( raise_error=NotImplementedError, ) - # get the udf name, return and input types - ( - udf_name, - is_pandas_udf, - is_dataframe_input, - return_type, - input_types, - opt_arg_defaults, - ) = process_registration_inputs( - self._session, TempObjectType.FUNCTION, func, return_type, input_types, name - ) - - current_schema = self._session.get_current_schema() - current_database = self._session.get_current_database() - udf_name = get_fully_qualified_name(udf_name, current_schema, current_database) - - # allow registering pandas UDF from udf(), - # but not allow registering non-pandas UDF from pandas_udf() - if from_pandas_udf_function and not is_pandas_udf: - raise ValueError( - "You cannot create a non-vectorized UDF using pandas_udf(). " - "Use udf() instead." + with self._lock: + # get the udf name, return and input types + ( + udf_name, + is_pandas_udf, + is_dataframe_input, + return_type, + input_types, + opt_arg_defaults, + ) = process_registration_inputs( + self._session, + TempObjectType.FUNCTION, + func, + return_type, + input_types, + name, ) - custom_python_runtime_version_allowed = False + current_schema = self._session.get_current_schema() + current_database = self._session.get_current_database() + udf_name = get_fully_qualified_name( + udf_name, current_schema, current_database + ) - if not custom_python_runtime_version_allowed: - check_python_runtime_version( - self._session._runtime_version_from_requirement + # allow registering pandas UDF from udf(), + # but not allow registering non-pandas UDF from pandas_udf() + if from_pandas_udf_function and not is_pandas_udf: + raise ValueError( + "You cannot create a non-vectorized UDF using pandas_udf(). " + "Use udf() instead." 
+ ) + + custom_python_runtime_version_allowed = False + + if not custom_python_runtime_version_allowed: + check_python_runtime_version( + self._session._runtime_version_from_requirement + ) + + if replace and if_not_exists: + raise ValueError("options replace and if_not_exists are incompatible") + + if udf_name in self._registry and if_not_exists: + return self._registry[udf_name] + + if udf_name in self._registry and not replace: + raise SnowparkSQLException( + f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.", + error_code="1304", + ) + + if packages: + pass # NO-OP + + # register + self._registry[udf_name] = MockUserDefinedFunction( + func, + return_type, + input_types, + udf_name, + strict=strict, + packages=packages, + use_session_imports=imports is None, ) - if replace and if_not_exists: - raise ValueError("options replace and if_not_exists are incompatible") + if type(func) is tuple: # update file registration + module_name = self._import_file(func[0], udf_name=udf_name) + self._registry[udf_name].func = (module_name, func[1]) - if udf_name in self._registry and if_not_exists: - return self._registry[udf_name] + if imports is not None: + for _import in imports: + if type(_import) is str: + self._import_file(_import, udf_name=udf_name) + else: + local_path, import_path = _import + self._import_file(local_path, import_path, udf_name=udf_name) - if udf_name in self._registry and not replace: - raise SnowparkSQLException( - f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.", - error_code="1304", - ) - - if packages: - pass # NO-OP - - # register - self._registry[udf_name] = MockUserDefinedFunction( - func, - return_type, - input_types, - udf_name, - strict=strict, - packages=packages, - use_session_imports=imports is None, - ) - - if type(func) is tuple: # update file registration - module_name = self._import_file(func[0], udf_name=udf_name) - self._registry[udf_name].func = (module_name, func[1]) - - if imports is not None: - for _import in imports: - if type(_import) is str: - self._import_file(_import, udf_name=udf_name) - else: - local_path, import_path = _import - self._import_file(local_path, import_path, udf_name=udf_name) - - return self._registry[udf_name] + return self._registry[udf_name] diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index 6960d0eb629..8f9834630b7 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -88,13 +88,13 @@ # TODO: SNOW-851745 make sure add all Snowpark pandas API general functions from modin.pandas import plotting # type: ignore[import] +from modin.pandas.dataframe import DataFrame from modin.pandas.series import Series from snowflake.snowpark.modin.pandas.api.extensions import ( register_dataframe_accessor, register_series_accessor, ) -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.general import ( bdate_range, concat, @@ -185,10 +185,8 @@ modin.pandas.base._ATTRS_NO_LOOKUP.update(_ATTRS_NO_LOOKUP) -# For any method defined on Series/DF, add telemetry to it if it: -# 1. Is defined directly on an upstream class -# 2. The method name does not start with an _, or is in TELEMETRY_PRIVATE_METHODS - +# For any method defined on Series/DF, add telemetry to it if the method name does not start with an +# _, or the method is in TELEMETRY_PRIVATE_METHODS. This includes methods defined as an extension/override. 
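The loops below re-register every public `Series`/`DataFrame` attribute through the extension API after passing it through `try_add_telemetry_to_attribute`. For readers unfamiliar with that pattern, here is a minimal sketch of what such an attribute wrapper can look like; the helper name `add_timing_telemetry` and the logging-based sink are assumptions made purely for illustration, not the actual implementation used in this module.

```python
# Illustrative sketch only (not the repository's try_add_telemetry_to_attribute):
# wrap callables with a simple timing log and return everything else unchanged.
import functools
import logging
import time

logger = logging.getLogger(__name__)


def add_timing_telemetry(attr_name, attr):
    """Return a telemetry-wrapped version of ``attr`` if it is callable, else ``attr`` itself."""
    if not callable(attr):
        # Properties and other descriptors fetched from the class are not callable;
        # a sketch like this would register them unchanged.
        return attr

    @functools.wraps(attr)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return attr(*args, **kwargs)
        finally:
            logger.debug(
                "pandas API call %s took %.4f seconds",
                attr_name,
                time.perf_counter() - start,
            )

    return wrapper
```

In the real module the wrapped attribute is then handed back to `register_series_accessor(attr_name)` / `register_dataframe_accessor(attr_name)`, so the telemetry-carrying version replaces the upstream Modin definition.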
for attr_name in dir(Series): # Since Series is defined in upstream Modin, all of its members were either defined upstream # or overridden by extension. @@ -197,11 +195,9 @@ try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name)) ) - -# TODO: SNOW-1063346 -# Since we still use the vendored version of DataFrame and the overrides for the top-level -# namespace haven't been performed yet, we need to set properties on the vendored version for attr_name in dir(DataFrame): + # Since DataFrame is defined in upstream Modin, all of its members were either defined upstream + # or overridden by extension. if not attr_name.startswith("_") or attr_name in TELEMETRY_PRIVATE_METHODS: register_dataframe_accessor(attr_name)( try_add_telemetry_to_attribute(attr_name, getattr(DataFrame, attr_name)) diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py index 6a34f50e42a..47d44835fe4 100644 --- a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py @@ -19,9 +19,12 @@ # existing code originally distributed by the Modin project, under the Apache License, # Version 2.0. -from modin.pandas.api.extensions import register_series_accessor +from modin.pandas.api.extensions import ( + register_dataframe_accessor, + register_series_accessor, +) -from .extensions import register_dataframe_accessor, register_pd_accessor +from .extensions import register_pd_accessor __all__ = [ "register_dataframe_accessor", diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py index 45896292e74..05424c92072 100644 --- a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py @@ -86,49 +86,6 @@ def decorator(new_attr: Any): return decorator -def register_dataframe_accessor(name: str): - """ - Registers a dataframe attribute with the name provided. - This is a decorator that assigns a new attribute to DataFrame. It can be used - with the following syntax: - ``` - @register_dataframe_accessor("new_method") - def my_new_dataframe_method(*args, **kwargs): - # logic goes here - return - ``` - The new attribute can then be accessed with the name provided: - ``` - df.new_method(*my_args, **my_kwargs) - ``` - - If you want a property accessor, you must annotate with @property - after the call to this function: - ``` - @register_dataframe_accessor("new_prop") - @property - def my_new_dataframe_property(*args, **kwargs): - return _prop - ``` - - Parameters - ---------- - name : str - The name of the attribute to assign to DataFrame. - Returns - ------- - decorator - Returns the decorator function. - """ - import snowflake.snowpark.modin.pandas as pd - - return _set_attribute_on_obj( - name, - pd.dataframe._DATAFRAME_EXTENSIONS_, - pd.dataframe.DataFrame, - ) - - def register_pd_accessor(name: str): """ Registers a pd namespace attribute with the name provided. diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py deleted file mode 100644 index 83893e83e9c..00000000000 --- a/src/snowflake/snowpark/modin/pandas/dataframe.py +++ /dev/null @@ -1,3511 +0,0 @@ -# -# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. -# - -# Licensed to Modin Development Team under one or more contributor license agreements. 
-# See the NOTICE file distributed with this work for additional information regarding -# copyright ownership. The Modin Development Team licenses this file to you under the -# Apache License, Version 2.0 (the "License"); you may not use this file except in -# compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under -# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific language -# governing permissions and limitations under the License. - -# Code in this file may constitute partial or total reimplementation, or modification of -# existing code originally distributed by the Modin project, under the Apache License, -# Version 2.0. - -"""Module houses ``DataFrame`` class, that is distributed version of ``pandas.DataFrame``.""" - -from __future__ import annotations - -import collections -import datetime -import functools -import itertools -import re -import sys -import warnings -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from logging import getLogger -from typing import IO, Any, Callable, Literal - -import numpy as np -import pandas -from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor -from modin.pandas.base import BasePandasDataset - -# from . import _update_engine -from modin.pandas.iterator import PartitionIterator -from modin.pandas.series import Series -from pandas._libs.lib import NoDefault, no_default -from pandas._typing import ( - AggFuncType, - AnyArrayLike, - Axes, - Axis, - CompressionOptions, - FilePath, - FillnaOptions, - IgnoreRaise, - IndexLabel, - Level, - PythonFuncType, - Renamer, - Scalar, - StorageOptions, - Suffixes, - WriteBuffer, -) -from pandas.core.common import apply_if_callable, is_bool_indexer -from pandas.core.dtypes.common import ( - infer_dtype_from_object, - is_bool_dtype, - is_dict_like, - is_list_like, - is_numeric_dtype, -) -from pandas.core.dtypes.inference import is_hashable, is_integer -from pandas.core.indexes.frozen import FrozenList -from pandas.io.formats.printing import pprint_thing -from pandas.util._validators import validate_bool_kwarg - -from snowflake.snowpark.modin import pandas as pd -from snowflake.snowpark.modin.pandas.groupby import ( - DataFrameGroupBy, - validate_groupby_args, -) -from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( - SnowparkPandasRowPartitionIterator, -) -from snowflake.snowpark.modin.pandas.utils import ( - create_empty_native_pandas_frame, - from_non_pandas, - from_pandas, - is_scalar, - raise_if_native_pandas_objects, - replace_external_data_keys_with_empty_pandas_series, - replace_external_data_keys_with_query_compiler, -) -from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated -from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike -from snowflake.snowpark.modin.plugin.utils.error_message import ( - ErrorMessage, - dataframe_not_implemented, -) -from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP -from snowflake.snowpark.modin.plugin.utils.warning_message import ( - SET_DATAFRAME_ATTRIBUTE_WARNING, - WarningMessage, -) -from snowflake.snowpark.modin.utils import _inherit_docstrings, hashable, to_pandas -from snowflake.snowpark.udf import UserDefinedFunction - -logger = getLogger(__name__) - 
-DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = ( - "Currently do not support Series or list-like keys with range-like values" -) - -DF_SETITEM_SLICE_AS_SCALAR_VALUE = ( - "Currently do not support assigning a slice value as if it's a scalar value" -) - -DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = ( - "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark " - "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which " - "can work on the entire DataFrame in one shot." -) - -# Dictionary of extensions assigned to this class -_DATAFRAME_EXTENSIONS_ = {} - - -@_inherit_docstrings( - pandas.DataFrame, - excluded=[ - pandas.DataFrame.flags, - pandas.DataFrame.cov, - pandas.DataFrame.merge, - pandas.DataFrame.reindex, - pandas.DataFrame.to_parquet, - pandas.DataFrame.fillna, - ], - apilink="pandas.DataFrame", -) -class DataFrame(BasePandasDataset): - _pandas_class = pandas.DataFrame - - def __init__( - self, - data=None, - index=None, - columns=None, - dtype=None, - copy=None, - query_compiler=None, - ) -> None: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # Siblings are other dataframes that share the same query compiler. We - # use this list to update inplace when there is a shallow copy. - from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native - - self._siblings = [] - - # Engine.subscribe(_update_engine) - if isinstance(data, (DataFrame, Series)): - self._query_compiler = data._query_compiler.copy() - if index is not None and any(i not in data.index for i in index): - ErrorMessage.not_implemented( - "Passing non-existant columns or index values to constructor not" - + " yet implemented." - ) # pragma: no cover - if isinstance(data, Series): - # We set the column name if it is not in the provided Series - if data.name is None: - self.columns = [0] if columns is None else columns - # If the columns provided are not in the named Series, pandas clears - # the DataFrame and sets columns to the columns provided. - elif columns is not None and data.name not in columns: - self._query_compiler = from_pandas( - self.__constructor__(columns=columns) - )._query_compiler - if index is not None: - self._query_compiler = data.loc[index]._query_compiler - elif columns is None and index is None: - data._add_sibling(self) - else: - if columns is not None and any(i not in data.columns for i in columns): - ErrorMessage.not_implemented( - "Passing non-existant columns or index values to constructor not" - + " yet implemented." 
- ) # pragma: no cover - if index is None: - index = slice(None) - if columns is None: - columns = slice(None) - self._query_compiler = data.loc[index, columns]._query_compiler - - # Check type of data and use appropriate constructor - elif query_compiler is None: - distributed_frame = from_non_pandas(data, index, columns, dtype) - if distributed_frame is not None: - self._query_compiler = distributed_frame._query_compiler - return - - if isinstance(data, pandas.Index): - pass - elif is_list_like(data) and not is_dict_like(data): - old_dtype = getattr(data, "dtype", None) - values = [ - obj._to_pandas() if isinstance(obj, Series) else obj for obj in data - ] - if isinstance(data, np.ndarray): - data = np.array(values, dtype=old_dtype) - else: - try: - data = type(data)(values, dtype=old_dtype) - except TypeError: - data = values - elif is_dict_like(data) and not isinstance( - data, (pandas.Series, Series, pandas.DataFrame, DataFrame) - ): - if columns is not None: - data = {key: value for key, value in data.items() if key in columns} - - if len(data) and all(isinstance(v, Series) for v in data.values()): - from .general import concat - - new_qc = concat( - data.values(), axis=1, keys=data.keys() - )._query_compiler - - if dtype is not None: - new_qc = new_qc.astype({col: dtype for col in new_qc.columns}) - if index is not None: - new_qc = new_qc.reindex( - axis=0, labels=try_convert_index_to_native(index) - ) - if columns is not None: - new_qc = new_qc.reindex( - axis=1, labels=try_convert_index_to_native(columns) - ) - - self._query_compiler = new_qc - return - - data = { - k: v._to_pandas() if isinstance(v, Series) else v - for k, v in data.items() - } - pandas_df = pandas.DataFrame( - data=try_convert_index_to_native(data), - index=try_convert_index_to_native(index), - columns=try_convert_index_to_native(columns), - dtype=dtype, - copy=copy, - ) - self._query_compiler = from_pandas(pandas_df)._query_compiler - else: - self._query_compiler = query_compiler - - def __repr__(self): - """ - Return a string representation for a particular ``DataFrame``. - - Returns - ------- - str - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - num_rows = pandas.get_option("display.max_rows") or len(self) - # see _repr_html_ for comment, allow here also all column behavior - num_cols = pandas.get_option("display.max_columns") or len(self.columns) - - ( - row_count, - col_count, - repr_df, - ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x") - result = repr(repr_df) - - # if truncated, add shape information - if is_repr_truncated(row_count, col_count, num_rows, num_cols): - # The split here is so that we don't repr pandas row lengths. - return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format( - row_count, col_count - ) - else: - return result - - def _repr_html_(self): # pragma: no cover - """ - Return a html representation for a particular ``DataFrame``. - - Returns - ------- - str - - Notes - ----- - Supports pandas `display.max_rows` and `display.max_columns` options. - """ - num_rows = pandas.get_option("display.max_rows") or 60 - # Modin uses here 20 as default, but this does not coincide well with pandas option. Therefore allow - # here value=0 which means display all columns. 
- num_cols = pandas.get_option("display.max_columns") - - ( - row_count, - col_count, - repr_df, - ) = self._query_compiler.build_repr_df(num_rows, num_cols) - result = repr_df._repr_html_() - - if is_repr_truncated(row_count, col_count, num_rows, num_cols): - # We split so that we insert our correct dataframe dimensions. - return ( - result.split("

")[0] - + f"

{row_count} rows × {col_count} columns

\n" - ) - else: - return result - - def _get_columns(self) -> pandas.Index: - """ - Get the columns for this Snowpark pandas ``DataFrame``. - - Returns - ------- - Index - The all columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.columns - - def _set_columns(self, new_columns: Axes) -> None: - """ - Set the columns for this Snowpark pandas ``DataFrame``. - - Parameters - ---------- - new_columns : - The new columns to set. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._update_inplace( - new_query_compiler=self._query_compiler.set_columns(new_columns) - ) - - columns = property(_get_columns, _set_columns) - - @property - def ndim(self) -> int: - return 2 - - def drop_duplicates( - self, subset=None, keep="first", inplace=False, ignore_index=False - ): # noqa: PR01, RT01, D200 - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - """ - Return ``DataFrame`` with duplicate rows removed. - """ - return super().drop_duplicates( - subset=subset, keep=keep, inplace=inplace, ignore_index=ignore_index - ) - - def dropna( - self, - *, - axis: Axis = 0, - how: str | NoDefault = no_default, - thresh: int | NoDefault = no_default, - subset: IndexLabel = None, - inplace: bool = False, - ): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super()._dropna( - axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace - ) - - @property - def dtypes(self): # noqa: RT01, D200 - """ - Return the dtypes in the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.dtypes - - def duplicated( - self, subset: Hashable | Sequence[Hashable] = None, keep: DropKeep = "first" - ): - """ - Return boolean ``Series`` denoting duplicate rows. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - df = self[subset] if subset is not None else self - new_qc = df._query_compiler.duplicated(keep=keep) - duplicates = self._reduce_dimension(new_qc) - # remove Series name which was assigned automatically by .apply in QC - # this is pandas behavior, i.e., if duplicated result is a series, no name is returned - duplicates.name = None - return duplicates - - @property - def empty(self) -> bool: - """ - Indicate whether ``DataFrame`` is empty. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return len(self.columns) == 0 or len(self) == 0 - - @property - def axes(self): - """ - Return a list representing the axes of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return [self.index, self.columns] - - @property - def shape(self) -> tuple[int, int]: - """ - Return a tuple representing the dimensionality of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return len(self), len(self.columns) - - def add_prefix(self, prefix): - """ - Prefix labels with string `prefix`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas converts non-string prefix values into str and adds it to the column labels. - return self.__constructor__( - query_compiler=self._query_compiler.add_substring( - str(prefix), substring_type="prefix", axis=1 - ) - ) - - def add_suffix(self, suffix): - """ - Suffix labels with string `suffix`. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas converts non-string suffix values into str and appends it to the column labels. - return self.__constructor__( - query_compiler=self._query_compiler.add_substring( - str(suffix), substring_type="suffix", axis=1 - ) - ) - - @dataframe_not_implemented() - def map( - self, func, na_action: str | None = None, **kwargs - ) -> DataFrame: # pragma: no cover - if not callable(func): - raise ValueError(f"'{type(func)}' object is not callable") - return self.__constructor__( - query_compiler=self._query_compiler.map(func, na_action=na_action, **kwargs) - ) - - def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not callable(func): - raise TypeError(f"{func} is not callable") - return self.__constructor__( - query_compiler=self._query_compiler.applymap( - func, na_action=na_action, **kwargs - ) - ) - - def aggregate( - self, func: AggFuncType = None, axis: Axis = 0, *args: Any, **kwargs: Any - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().aggregate(func, axis, *args, **kwargs) - - agg = aggregate - - def apply( - self, - func: AggFuncType | UserDefinedFunction, - axis: Axis = 0, - raw: bool = False, - result_type: Literal["expand", "reduce", "broadcast"] | None = None, - args=(), - **kwargs, - ): - """ - Apply a function along an axis of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) - query_compiler = self._query_compiler.apply( - func, - axis, - raw=raw, - result_type=result_type, - args=args, - **kwargs, - ) - if not isinstance(query_compiler, type(self._query_compiler)): - # A scalar was returned - return query_compiler - - # If True, it is an unamed series. - # Theoretically, if df.apply returns a Series, it will only be an unnamed series - # because the function is supposed to be series -> scalar. - if query_compiler._modin_frame.is_unnamed_series(): - return Series(query_compiler=query_compiler) - else: - return self.__constructor__(query_compiler=query_compiler) - - def groupby( - self, - by=None, - axis: Axis | NoDefault = no_default, - level: IndexLabel | None = None, - as_index: bool = True, - sort: bool = True, - group_keys: bool = True, - observed: bool | NoDefault = no_default, - dropna: bool = True, - ): - """ - Group ``DataFrame`` using a mapper or by a ``Series`` of columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if axis is not no_default: - axis = self._get_axis_number(axis) - if axis == 1: - warnings.warn( - "DataFrame.groupby with axis=1 is deprecated. Do " - + "`frame.T.groupby(...)` without axis instead.", - FutureWarning, - stacklevel=1, - ) - else: - warnings.warn( - "The 'axis' keyword in DataFrame.groupby is deprecated and " - + "will be removed in a future version.", - FutureWarning, - stacklevel=1, - ) - else: - axis = 0 - - validate_groupby_args(by, level, observed) - - axis = self._get_axis_number(axis) - - if axis != 0 and as_index is False: - raise ValueError("as_index=False only valid for axis=0") - - idx_name = None - - if ( - not isinstance(by, Series) - and is_list_like(by) - and len(by) == 1 - # if by is a list-like of (None,), we have to keep it as a list because - # None may be referencing a column or index level whose label is - # `None`, and by=None wold mean that there is no `by` param. 
- and by[0] is not None - ): - by = by[0] - - if hashable(by) and ( - not callable(by) and not isinstance(by, (pandas.Grouper, FrozenList)) - ): - idx_name = by - elif isinstance(by, Series): - idx_name = by.name - if by._parent is self: - # if the SnowSeries comes from the current dataframe, - # convert it to labels directly for easy processing - by = by.name - elif is_list_like(by): - if axis == 0 and all( - ( - (hashable(o) and (o in self)) - or isinstance(o, Series) - or (is_list_like(o) and len(o) == len(self.shape[axis])) - ) - for o in by - ): - # plit 'by's into those that belongs to the self (internal_by) - # and those that doesn't (external_by). For SnowSeries that belongs - # to current DataFrame, we convert it to labels for easy process. - internal_by, external_by = [], [] - - for current_by in by: - if hashable(current_by): - internal_by.append(current_by) - elif isinstance(current_by, Series): - if current_by._parent is self: - internal_by.append(current_by.name) - else: - external_by.append(current_by) # pragma: no cover - else: - external_by.append(current_by) - - by = internal_by + external_by - - return DataFrameGroupBy( - self, - by, - axis, - level, - as_index, - sort, - group_keys, - idx_name, - observed=observed, - dropna=dropna, - ) - - def keys(self): # noqa: RT01, D200 - """ - Get columns of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.columns - - def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200 - """ - Transpose index and columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if copy: - WarningMessage.ignored_argument( - operation="transpose", - argument="copy", - message="Transpose ignore copy argument in Snowpark pandas API", - ) - - if args: - WarningMessage.ignored_argument( - operation="transpose", - argument="args", - message="Transpose ignores args in Snowpark pandas API", - ) - - return self.__constructor__(query_compiler=self._query_compiler.transpose()) - - T = property(transpose) - - def add( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get addition of ``DataFrame`` and `other`, element-wise (binary operator `add`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "add", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def assign(self, **kwargs): # noqa: PR01, RT01, D200 - """ - Assign new columns to a ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - df = self.copy() - for k, v in kwargs.items(): - if callable(v): - df[k] = v(df) - else: - df[k] = v - return df - - @dataframe_not_implemented() - def boxplot( - self, - column=None, - by=None, - ax=None, - fontsize=None, - rot=0, - grid=True, - figsize=None, - layout=None, - return_type=None, - backend=None, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Make a box plot from ``DataFrame`` columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return to_pandas(self).boxplot( - column=column, - by=by, - ax=ax, - fontsize=fontsize, - rot=rot, - grid=grid, - figsize=figsize, - layout=layout, - return_type=return_type, - backend=backend, - **kwargs, - ) - - @dataframe_not_implemented() - def combine( - self, other, func, fill_value=None, overwrite=True - ): # noqa: PR01, RT01, D200 - """ - Perform column-wise combine with another ``DataFrame``. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().combine(other, func, fill_value=fill_value, overwrite=overwrite) - - def compare( - self, - other, - align_axis=1, - keep_shape: bool = False, - keep_equal: bool = False, - result_names=("self", "other"), - ) -> DataFrame: # noqa: PR01, RT01, D200 - """ - Compare to another ``DataFrame`` and show the differences. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not isinstance(other, DataFrame): - raise TypeError(f"Cannot compare DataFrame to {type(other)}") - other = self._validate_other(other, 0, compare_index=True) - return self.__constructor__( - query_compiler=self._query_compiler.compare( - other, - align_axis=align_axis, - keep_shape=keep_shape, - keep_equal=keep_equal, - result_names=result_names, - ) - ) - - def corr( - self, - method: str | Callable = "pearson", - min_periods: int | None = None, - numeric_only: bool = False, - ): # noqa: PR01, RT01, D200 - """ - Compute pairwise correlation of columns, excluding NA/null values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - corr_df = self - if numeric_only: - corr_df = self.drop( - columns=[ - i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) - ] - ) - return self.__constructor__( - query_compiler=corr_df._query_compiler.corr( - method=method, - min_periods=min_periods, - ) - ) - - @dataframe_not_implemented() - def corrwith( - self, other, axis=0, drop=False, method="pearson", numeric_only=False - ): # noqa: PR01, RT01, D200 - """ - Compute pairwise correlation. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, DataFrame): - other = other._query_compiler.to_pandas() - return self._default_to_pandas( - pandas.DataFrame.corrwith, - other, - axis=axis, - drop=drop, - method=method, - numeric_only=numeric_only, - ) - - @dataframe_not_implemented() - def cov( - self, - min_periods: int | None = None, - ddof: int | None = 1, - numeric_only: bool = False, - ): - """ - Compute pairwise covariance of columns, excluding NA/null values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.cov( - min_periods=min_periods, - ddof=ddof, - numeric_only=numeric_only, - ) - ) - - @dataframe_not_implemented() - def dot(self, other): # noqa: PR01, RT01, D200 - """ - Compute the matrix multiplication between the ``DataFrame`` and `other`. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - if isinstance(other, BasePandasDataset): - common = self.columns.union(other.index) - if len(common) > len(self.columns) or len(common) > len( - other - ): # pragma: no cover - raise ValueError("Matrices are not aligned") - - if isinstance(other, DataFrame): - return self.__constructor__( - query_compiler=self._query_compiler.dot( - other.reindex(index=common), squeeze_self=False - ) - ) - else: - return self._reduce_dimension( - query_compiler=self._query_compiler.dot( - other.reindex(index=common), squeeze_self=False - ) - ) - - other = np.asarray(other) - if self.shape[1] != other.shape[0]: - raise ValueError( - f"Dot product shape mismatch, {self.shape} vs {other.shape}" - ) - - if len(other.shape) > 1: - return self.__constructor__( - query_compiler=self._query_compiler.dot(other, squeeze_self=False) - ) - - return self._reduce_dimension( - query_compiler=self._query_compiler.dot(other, squeeze_self=False) - ) - - def eq(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Perform equality comparison of ``DataFrame`` and `other` (binary operator `eq`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("eq", other, axis=axis, level=level) - - def equals(self, other) -> bool: # noqa: PR01, RT01, D200 - """ - Test whether two objects contain the same elements. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, pandas.DataFrame): - # Copy into a Modin DataFrame to simplify logic below - other = self.__constructor__(other) - - if ( - type(self) is not type(other) - or not self.index.equals(other.index) - or not self.columns.equals(other.columns) - ): - return False - - result = self.__constructor__( - query_compiler=self._query_compiler.equals(other._query_compiler) - ) - return result.all(axis=None) - - def _update_var_dicts_in_kwargs(self, expr, kwargs): - """ - Copy variables with "@" prefix in `local_dict` and `global_dict` keys of kwargs. - - Parameters - ---------- - expr : str - The expression string to search variables with "@" prefix. - kwargs : dict - See the documentation for eval() for complete details on the keyword arguments accepted by query(). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if "@" not in expr: - return - frame = sys._getframe() - try: - f_locals = frame.f_back.f_back.f_back.f_back.f_locals - f_globals = frame.f_back.f_back.f_back.f_back.f_globals - finally: - del frame - local_names = set(re.findall(r"@([\w]+)", expr)) - local_dict = {} - global_dict = {} - - for name in local_names: - for dct_out, dct_in in ((local_dict, f_locals), (global_dict, f_globals)): - try: - dct_out[name] = dct_in[name] - except KeyError: - pass - - if local_dict: - local_dict.update(kwargs.get("local_dict") or {}) - kwargs["local_dict"] = local_dict - if global_dict: - global_dict.update(kwargs.get("global_dict") or {}) - kwargs["global_dict"] = global_dict - - @dataframe_not_implemented() - def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 - """ - Evaluate a string describing operations on ``DataFrame`` columns. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._validate_eval_query(expr, **kwargs) - inplace = validate_bool_kwarg(inplace, "inplace") - self._update_var_dicts_in_kwargs(expr, kwargs) - new_query_compiler = self._query_compiler.eval(expr, **kwargs) - return_type = type( - pandas.DataFrame(columns=self.columns) - .astype(self.dtypes) - .eval(expr, **kwargs) - ).__name__ - if return_type == type(self).__name__: - return self._create_or_update_from_compiler(new_query_compiler, inplace) - else: - if inplace: - raise ValueError("Cannot operate inplace if there is no assignment") - return getattr(sys.modules[self.__module__], return_type)( - query_compiler=new_query_compiler - ) - - def fillna( - self, - value: Hashable | Mapping | Series | DataFrame = None, - *, - method: FillnaOptions | None = None, - axis: Axis | None = None, - inplace: bool = False, - limit: int | None = None, - downcast: dict | None = None, - ) -> DataFrame | None: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().fillna( - self_is_series=False, - value=value, - method=method, - axis=axis, - inplace=inplace, - limit=limit, - downcast=downcast, - ) - - def floordiv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `floordiv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "floordiv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - @classmethod - @dataframe_not_implemented() - def from_dict( - cls, data, orient="columns", dtype=None, columns=None - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Construct ``DataFrame`` from dict of array-like or dicts. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return from_pandas( - pandas.DataFrame.from_dict( - data, orient=orient, dtype=dtype, columns=columns - ) - ) - - @classmethod - @dataframe_not_implemented() - def from_records( - cls, - data, - index=None, - exclude=None, - columns=None, - coerce_float=False, - nrows=None, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Convert structured or record ndarray to ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return from_pandas( - pandas.DataFrame.from_records( - data, - index=index, - exclude=exclude, - columns=columns, - coerce_float=coerce_float, - nrows=nrows, - ) - ) - - def ge(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get greater than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("ge", other, axis=axis, level=level) - - def gt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get greater than comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("gt", other, axis=axis, level=level) - - @dataframe_not_implemented() - def hist( - self, - column=None, - by=None, - grid=True, - xlabelsize=None, - xrot=None, - ylabelsize=None, - yrot=None, - ax=None, - sharex=False, - sharey=False, - figsize=None, - layout=None, - bins=10, - **kwds, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Make a histogram of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.hist, - column=column, - by=by, - grid=grid, - xlabelsize=xlabelsize, - xrot=xrot, - ylabelsize=ylabelsize, - yrot=yrot, - ax=ax, - sharex=sharex, - sharey=sharey, - figsize=figsize, - layout=layout, - bins=bins, - **kwds, - ) - - def info( - self, - verbose: bool | None = None, - buf: IO[str] | None = None, - max_cols: int | None = None, - memory_usage: bool | str | None = None, - show_counts: bool | None = None, - null_counts: bool | None = None, - ): # noqa: PR01, D200 - """ - Print a concise summary of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def put_str(src, output_len=None, spaces=2): - src = str(src) - return src.ljust(output_len if output_len else len(src)) + " " * spaces - - def format_size(num): - for x in ["bytes", "KB", "MB", "GB", "TB"]: - if num < 1024.0: - return f"{num:3.1f} {x}" - num /= 1024.0 - return f"{num:3.1f} PB" - - output = [] - - type_line = str(type(self)) - index_line = "SnowflakeIndex" - columns = self.columns - columns_len = len(columns) - dtypes = self.dtypes - dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}" - - if max_cols is None: - max_cols = 100 - - exceeds_info_cols = columns_len > max_cols - - if buf is None: - buf = sys.stdout - - if null_counts is None: - null_counts = not exceeds_info_cols - - if verbose is None: - verbose = not exceeds_info_cols - - if null_counts and verbose: - # We're gonna take items from `non_null_count` in a loop, which - # works kinda slow with `Modin.Series`, that's why we call `_to_pandas()` here - # that will be faster. 
- non_null_count = self.count()._to_pandas() - - if memory_usage is None: - memory_usage = True - - def get_header(spaces=2): - output = [] - head_label = " # " - column_label = "Column" - null_label = "Non-Null Count" - dtype_label = "Dtype" - non_null_label = " non-null" - delimiter = "-" - - lengths = {} - lengths["head"] = max(len(head_label), len(pprint_thing(len(columns)))) - lengths["column"] = max( - len(column_label), max(len(pprint_thing(col)) for col in columns) - ) - lengths["dtype"] = len(dtype_label) - dtype_spaces = ( - max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes)) - - lengths["dtype"] - ) - - header = put_str(head_label, lengths["head"]) + put_str( - column_label, lengths["column"] - ) - if null_counts: - lengths["null"] = max( - len(null_label), - max(len(pprint_thing(x)) for x in non_null_count) - + len(non_null_label), - ) - header += put_str(null_label, lengths["null"]) - header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces) - - output.append(header) - - delimiters = put_str(delimiter * lengths["head"]) + put_str( - delimiter * lengths["column"] - ) - if null_counts: - delimiters += put_str(delimiter * lengths["null"]) - delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces) - output.append(delimiters) - - return output, lengths - - output.extend([type_line, index_line]) - - def verbose_repr(output): - columns_line = f"Data columns (total {len(columns)} columns):" - header, lengths = get_header() - output.extend([columns_line, *header]) - for i, col in enumerate(columns): - i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]]) - - to_append = put_str(f" {i}", lengths["head"]) + put_str( - col_s, lengths["column"] - ) - if null_counts: - non_null = pprint_thing(non_null_count[col]) - to_append += put_str(f"{non_null} non-null", lengths["null"]) - to_append += put_str(dtype, lengths["dtype"], spaces=0) - output.append(to_append) - - def non_verbose_repr(output): - output.append(columns._summary(name="Columns")) - - if verbose: - verbose_repr(output) - else: - non_verbose_repr(output) - - output.append(dtypes_line) - - if memory_usage: - deep = memory_usage == "deep" - mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum() - mem_line = f"memory usage: {format_size(mem_usage_bytes)}" - - output.append(mem_line) - - output.append("") - buf.write("\n".join(output)) - - def insert( - self, - loc: int, - column: Hashable, - value: Scalar | AnyArrayLike, - allow_duplicates: bool | NoDefault = no_default, - ) -> None: - """ - Insert column into ``DataFrame`` at specified location. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - raise_if_native_pandas_objects(value) - if allow_duplicates is no_default: - allow_duplicates = False - if not allow_duplicates and column in self.columns: - raise ValueError(f"cannot insert {column}, already exists") - - if not isinstance(loc, int): - raise TypeError("loc must be int") - - # If columns labels are multilevel, we implement following behavior (this is - # name native pandas): - # Case 1: if 'column' is tuple it's length must be same as number of levels - # otherwise raise error. - # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in - # empty strings to match the length of column levels in self frame. - if self.columns.nlevels > 1: - if isinstance(column, tuple) and len(column) != self.columns.nlevels: - # same error as native pandas. 
- raise ValueError("Item must have length equal to number of levels.") - if not isinstance(column, tuple): - # Fill empty strings to match length of levels - suffix = [""] * (self.columns.nlevels - 1) - column = tuple([column] + suffix) - - # Dictionary keys are treated as index column and this should be joined with - # index of target dataframe. This behavior is similar to 'value' being DataFrame - # or Series, so we simply create Series from dict data here. - if isinstance(value, dict): - value = Series(value, name=column) - - if isinstance(value, DataFrame) or ( - isinstance(value, np.ndarray) and len(value.shape) > 1 - ): - # Supported numpy array shapes are - # 1. (N, ) -> Ex. [1, 2, 3] - # 2. (N, 1) -> Ex> [[1], [2], [3]] - if value.shape[1] != 1: - if isinstance(value, DataFrame): - # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin - raise ValueError( - f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead." - ) - else: - raise ValueError( - f"Expected a 1D array, got an array with shape {value.shape}" - ) - # Change numpy array shape from (N, 1) to (N, ) - if isinstance(value, np.ndarray): - value = value.squeeze(axis=1) - - if ( - is_list_like(value) - and not isinstance(value, (Series, DataFrame)) - and len(value) != self.shape[0] - and not 0 == self.shape[0] # dataframe holds no rows - ): - raise ValueError( - "Length of values ({}) does not match length of index ({})".format( - len(value), len(self) - ) - ) - if not -len(self.columns) <= loc <= len(self.columns): - raise IndexError( - f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}" - ) - elif loc < 0: - raise ValueError("unbounded slice") - - join_on_index = False - if isinstance(value, (Series, DataFrame)): - value = value._query_compiler - join_on_index = True - elif is_list_like(value): - value = Series(value, name=column)._query_compiler - - new_query_compiler = self._query_compiler.insert( - loc, column, value, join_on_index - ) - # In pandas, 'insert' operation is always inplace. - self._update_inplace(new_query_compiler=new_query_compiler) - - @dataframe_not_implemented() - def interpolate( - self, - method="linear", - axis=0, - limit=None, - inplace=False, - limit_direction: str | None = None, - limit_area=None, - downcast=None, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Fill NaN values using an interpolation method. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.interpolate, - method=method, - axis=axis, - limit=limit, - inplace=inplace, - limit_direction=limit_direction, - limit_area=limit_area, - downcast=downcast, - **kwargs, - ) - - def iterrows(self) -> Iterator[tuple[Hashable, Series]]: - """ - Iterate over ``DataFrame`` rows as (index, ``Series``) pairs. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def iterrow_builder(s): - """Return tuple of the given `s` parameter name and the parameter themselves.""" - return s.name, s - - # Raise warning message since iterrows is very inefficient. - WarningMessage.single_warning( - DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows") - ) - - partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder) - yield from partition_iterator - - def items(self): # noqa: D200 - """ - Iterate over (column name, ``Series``) pairs. 
- """ - - def items_builder(s): - """Return tuple of the given `s` parameter name and the parameter themselves.""" - return s.name, s - - partition_iterator = PartitionIterator(self, 1, items_builder) - yield from partition_iterator - - def itertuples( - self, index: bool = True, name: str | None = "Pandas" - ) -> Iterable[tuple[Any, ...]]: - """ - Iterate over ``DataFrame`` rows as ``namedtuple``-s. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - def itertuples_builder(s): - """Return the next namedtuple.""" - # s is the Series of values in the current row. - fields = [] # column names - data = [] # values under each column - - if index: - data.append(s.name) - fields.append("Index") - - # Fill column names and values. - fields.extend(list(self.columns)) - data.extend(s) - - if name is not None: - # Creating the namedtuple. - itertuple = collections.namedtuple(name, fields, rename=True) - return itertuple._make(data) - - # When the name is None, return a regular tuple. - return tuple(data) - - # Raise warning message since itertuples is very inefficient. - WarningMessage.single_warning( - DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples") - ) - return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True) - - def join( - self, - other: DataFrame | Series | Iterable[DataFrame | Series], - on: IndexLabel | None = None, - how: str = "left", - lsuffix: str = "", - rsuffix: str = "", - sort: bool = False, - validate: str | None = None, - ) -> DataFrame: - """ - Join columns of another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - for o in other if isinstance(other, list) else [other]: - raise_if_native_pandas_objects(o) - - # Similar to native pandas we implement 'join' using 'pd.merge' method. - # Following code is copied from native pandas (with few changes explained below) - # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002 - if isinstance(other, Series): - # Same error as native pandas. - if other.name is None: - raise ValueError("Other Series must have a name") - other = DataFrame(other) - elif is_list_like(other): - if any([isinstance(o, Series) and o.name is None for o in other]): - raise ValueError("Other Series must have a name") - - if isinstance(other, DataFrame): - if how == "cross": - return pd.merge( - self, - other, - how=how, - on=on, - suffixes=(lsuffix, rsuffix), - sort=sort, - validate=validate, - ) - return pd.merge( - self, - other, - left_on=on, - how=how, - left_index=on is None, - right_index=True, - suffixes=(lsuffix, rsuffix), - sort=sort, - validate=validate, - ) - else: # List of DataFrame/Series - # Same error as native pandas. - if on is not None: - raise ValueError( - "Joining multiple DataFrames only supported for joining on index" - ) - - # Same error as native pandas. - if rsuffix or lsuffix: - raise ValueError( - "Suffixes not supported when joining multiple DataFrames" - ) - - # NOTE: These are not the differences between Snowpark pandas API and pandas behavior - # these are differences between native pandas join behavior when join - # frames have unique index or not. - - # In native pandas logic to join multiple DataFrames/Series is data - # dependent. Under the hood it will either use 'concat' or 'merge' API - # Case 1. If all objects being joined have unique index use 'concat' (axis=1) - # Case 2. Otherwise use 'merge' API by looping through objects left to right. 
- # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046 - - # Even though concat (axis=1) and merge are very similar APIs they have - # some differences which leads to inconsistent behavior in native pandas. - # 1. Treatment of un-named Series - # Case #1: Un-named series is allowed in concat API. Objects are joined - # successfully by assigning a number as columns name (see 'concat' API - # documentation for details on treatment of un-named series). - # Case #2: It raises 'ValueError: Other Series must have a name' - - # 2. how='right' - # Case #1: 'concat' API doesn't support right join. It raises - # 'ValueError: Only can inner (intersect) or outer (union) join the other axis' - # Case #2: Merges successfully. - - # 3. Joining frames with duplicate labels but no conflict with other frames - # Example: self = DataFrame(... columns=["A", "B"]) - # other = [DataFrame(... columns=["C", "C"])] - # Case #1: 'ValueError: Indexes have overlapping values' - # Case #2: Merged successfully. - - # In addition to this, native pandas implementation also leads to another - # type of inconsistency where left.join(other, ...) and - # left.join([other], ...) might behave differently for cases mentioned - # above. - # Example: - # import pandas as pd - # df = pd.DataFrame({"a": [4, 5]}) - # other = pd.Series([1, 2]) - # df.join([other]) # this is successful - # df.join(other) # this raises 'ValueError: Other Series must have a name' - - # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API - # to join multiple DataFrame/Series. So always follow the behavior - # documented as Case #2 above. - - joined = self - for frame in other: - if isinstance(frame, DataFrame): - overlapping_cols = set(joined.columns).intersection( - set(frame.columns) - ) - if len(overlapping_cols) > 0: - # Native pandas raises: 'Indexes have overlapping values' - # We differ slightly from native pandas message to make it more - # useful to users. - raise ValueError( - f"Join dataframes have overlapping column labels: {overlapping_cols}" - ) - joined = pd.merge( - joined, - frame, - how=how, - left_index=True, - right_index=True, - validate=validate, - sort=sort, - suffixes=(None, None), - ) - return joined - - def isna(self): - return super().isna() - - def isnull(self): - return super().isnull() - - @dataframe_not_implemented() - def isetitem(self, loc, value): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.isetitem, - loc=loc, - value=value, - ) - - def le(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get less than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("le", other, axis=axis, level=level) - - def lt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get less than comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("lt", other, axis=axis, level=level) - - def melt( - self, - id_vars=None, - value_vars=None, - var_name=None, - value_name="value", - col_level=None, - ignore_index=True, - ): # noqa: PR01, RT01, D200 - """ - Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if id_vars is None: - id_vars = [] - if not is_list_like(id_vars): - id_vars = [id_vars] - if value_vars is None: - # Behavior of Index.difference changed in 2.2.x - # https://github.com/pandas-dev/pandas/pull/55113 - # This change needs upstream to Modin: - # https://github.com/modin-project/modin/issues/7206 - value_vars = self.columns.drop(id_vars) - if var_name is None: - columns_name = self._query_compiler.get_index_name(axis=1) - var_name = columns_name if columns_name is not None else "variable" - return self.__constructor__( - query_compiler=self._query_compiler.melt( - id_vars=id_vars, - value_vars=value_vars, - var_name=var_name, - value_name=value_name, - col_level=col_level, - ignore_index=ignore_index, - ) - ) - - @dataframe_not_implemented() - def memory_usage(self, index=True, deep=False): # noqa: PR01, RT01, D200 - """ - Return the memory usage of each column in bytes. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - if index: - result = self._reduce_dimension( - self._query_compiler.memory_usage(index=False, deep=deep) - ) - index_value = self.index.memory_usage(deep=deep) - return pd.concat( - [Series(index_value, index=["Index"]), result] - ) # pragma: no cover - return super().memory_usage(index=index, deep=deep) - - def merge( - self, - right: DataFrame | Series, - how: str = "inner", - on: IndexLabel | None = None, - left_on: Hashable - | AnyArrayLike - | Sequence[Hashable | AnyArrayLike] - | None = None, - right_on: Hashable - | AnyArrayLike - | Sequence[Hashable | AnyArrayLike] - | None = None, - left_index: bool = False, - right_index: bool = False, - sort: bool = False, - suffixes: Suffixes = ("_x", "_y"), - copy: bool = True, - indicator: bool = False, - validate: str | None = None, - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # Raise error if native pandas objects are passed. - raise_if_native_pandas_objects(right) - - if isinstance(right, Series) and right.name is None: - raise ValueError("Cannot merge a Series without a name") - if not isinstance(right, (Series, DataFrame)): - raise TypeError( - f"Can only merge Series or DataFrame objects, a {type(right)} was passed" - ) - - if isinstance(right, Series): - right_column_nlevels = ( - len(right.name) if isinstance(right.name, tuple) else 1 - ) - else: - right_column_nlevels = right.columns.nlevels - if self.columns.nlevels != right_column_nlevels: - # This is deprecated in native pandas. We raise explicit error for this. - raise ValueError( - "Can not merge objects with different column levels." - + f" ({self.columns.nlevels} levels on the left," - + f" {right_column_nlevels} on the right)" - ) - - # Merge empty native pandas dataframes for error checking. Otherwise, it will - # require a lot of logic to be written. This takes care of raising errors for - # following scenarios: - # 1. Only 'left_index' is set to True. - # 2. Only 'right_index is set to True. - # 3. Only 'left_on' is provided. - # 4. Only 'right_on' is provided. - # 5. 'on' and 'left_on' both are provided - # 6. 'on' and 'right_on' both are provided - # 7. 'on' and 'left_index' both are provided - # 8. 'on' and 'right_index' both are provided - # 9. 'left_on' and 'left_index' both are provided - # 10. 'right_on' and 'right_index' both are provided - # 11. Length mismatch between 'left_on' and 'right_on' - # 12. 'left_index' is not a bool - # 13. 'right_index' is not a bool - # 14. 
'on' is not None and how='cross' - # 15. 'left_on' is not None and how='cross' - # 16. 'right_on' is not None and how='cross' - # 17. 'left_index' is True and how='cross' - # 18. 'right_index' is True and how='cross' - # 19. Unknown label in 'on', 'left_on' or 'right_on' - # 20. Provided 'suffixes' is not sufficient to resolve conflicts. - # 21. Merging on column with duplicate labels. - # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'} - # 23. conflict with existing labels for array-like join key - # 24. 'indicator' argument is not bool or str - # 25. indicator column label conflicts with existing data labels - create_empty_native_pandas_frame(self).merge( - create_empty_native_pandas_frame(right), - on=on, - how=how, - left_on=replace_external_data_keys_with_empty_pandas_series(left_on), - right_on=replace_external_data_keys_with_empty_pandas_series(right_on), - left_index=left_index, - right_index=right_index, - suffixes=suffixes, - indicator=indicator, - ) - - return self.__constructor__( - query_compiler=self._query_compiler.merge( - right._query_compiler, - how=how, - on=on, - left_on=replace_external_data_keys_with_query_compiler(self, left_on), - right_on=replace_external_data_keys_with_query_compiler( - right, right_on - ), - left_index=left_index, - right_index=right_index, - sort=sort, - suffixes=suffixes, - copy=copy, - indicator=indicator, - validate=validate, - ) - ) - - def mod( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `mod`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "mod", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def mul( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "mul", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - multiply = mul - - def rmul( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rmul", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get not equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ne`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("ne", other, axis=axis, level=level) - - def nlargest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200 - """ - Return the first `n` rows ordered by `columns` in descending order. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.nlargest(n, columns, keep) - ) - - def nsmallest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200 - """ - Return the first `n` rows ordered by `columns` in ascending order. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.nsmallest( - n=n, columns=columns, keep=keep - ) - ) - - def unstack( - self, - level: int | str | list = -1, - fill_value: int | str | dict = None, - sort: bool = True, - ): - """ - Pivot a level of the (necessarily hierarchical) index labels. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # This ensures that non-pandas MultiIndex objects are caught. - nlevels = self._query_compiler.nlevels() - is_multiindex = nlevels > 1 - - if not is_multiindex or ( - is_multiindex and is_list_like(level) and len(level) == nlevels - ): - return self._reduce_dimension( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=False - ) - ) - else: - return self.__constructor__( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=False - ) - ) - - def pivot( - self, - *, - columns: Any, - index: Any | NoDefault = no_default, - values: Any | NoDefault = no_default, - ): - """ - Return reshaped DataFrame organized by given index / column values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if index is no_default: - index = None # pragma: no cover - if values is no_default: - values = None - - # if values is not specified, it should be the remaining columns not in - # index or columns - if values is None: - values = list(self.columns) - if index is not None: - values = [v for v in values if v not in index] - if columns is not None: - values = [v for v in values if v not in columns] - - return self.__constructor__( - query_compiler=self._query_compiler.pivot( - index=index, columns=columns, values=values - ) - ) - - def pivot_table( - self, - values=None, - index=None, - columns=None, - aggfunc="mean", - fill_value=None, - margins=False, - dropna=True, - margins_name="All", - observed=False, - sort=True, - ): - """ - Create a spreadsheet-style pivot table as a ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - result = self.__constructor__( - query_compiler=self._query_compiler.pivot_table( - index=index, - values=values, - columns=columns, - aggfunc=aggfunc, - fill_value=fill_value, - margins=margins, - dropna=dropna, - margins_name=margins_name, - observed=observed, - sort=sort, - ) - ) - return result - - @dataframe_not_implemented() - @property - def plot( - self, - x=None, - y=None, - kind="line", - ax=None, - subplots=False, - sharex=None, - sharey=False, - layout=None, - figsize=None, - use_index=True, - title=None, - grid=None, - legend=True, - style=None, - logx=False, - logy=False, - loglog=False, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - rot=None, - fontsize=None, - colormap=None, - table=False, - yerr=None, - xerr=None, - secondary_y=False, - sort_columns=False, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Make plots of ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._to_pandas().plot - - def pow( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "pow", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - @dataframe_not_implemented() - def prod( - self, - axis=None, - skipna=True, - numeric_only=False, - min_count=0, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Return the product of the values over the requested axis. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - validate_bool_kwarg(skipna, "skipna", none_allowed=False) - axis = self._get_axis_number(axis) - axis_to_apply = self.columns if axis else self.index - if ( - skipna is not False - and numeric_only is None - and min_count > len(axis_to_apply) - ): - new_index = self.columns if not axis else self.index - return Series( - [np.nan] * len(new_index), index=new_index, dtype=np.dtype("object") - ) - - data = self._validate_dtypes_sum_prod_mean(axis, numeric_only, ignore_axis=True) - if min_count > 1: - return data._reduce_dimension( - data._query_compiler.prod_min_count( - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - min_count=min_count, - **kwargs, - ) - ) - return data._reduce_dimension( - data._query_compiler.prod( - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - min_count=min_count, - **kwargs, - ) - ) - - product = prod - - def quantile( - self, - q: Scalar | ListLike = 0.5, - axis: Axis = 0, - numeric_only: bool = False, - interpolation: Literal[ - "linear", "lower", "higher", "midpoint", "nearest" - ] = "linear", - method: Literal["single", "table"] = "single", - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().quantile( - q=q, - axis=axis, - numeric_only=numeric_only, - interpolation=interpolation, - method=method, - ) - - @dataframe_not_implemented() - def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 - """ - Query the columns of a ``DataFrame`` with a boolean expression. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._update_var_dicts_in_kwargs(expr, kwargs) - self._validate_eval_query(expr, **kwargs) - inplace = validate_bool_kwarg(inplace, "inplace") - new_query_compiler = self._query_compiler.query(expr, **kwargs) - return self._create_or_update_from_compiler(new_query_compiler, inplace) - - def rename( - self, - mapper: Renamer | None = None, - *, - index: Renamer | None = None, - columns: Renamer | None = None, - axis: Axis | None = None, - copy: bool | None = None, - inplace: bool = False, - level: Level | None = None, - errors: IgnoreRaise = "ignore", - ) -> DataFrame | None: - """ - Alter axes labels. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - if mapper is None and index is None and columns is None: - raise TypeError("must pass an index to rename") - - if index is not None or columns is not None: - if axis is not None: - raise TypeError( - "Cannot specify both 'axis' and any of 'index' or 'columns'" - ) - elif mapper is not None: - raise TypeError( - "Cannot specify both 'mapper' and any of 'index' or 'columns'" - ) - else: - # use the mapper argument - if axis and self._get_axis_number(axis) == 1: - columns = mapper - else: - index = mapper - - if copy is not None: - WarningMessage.ignored_argument( - operation="dataframe.rename", - argument="copy", - message="copy parameter has been ignored with Snowflake execution engine", - ) - - if isinstance(index, dict): - index = Series(index) - - new_qc = self._query_compiler.rename( - index_renamer=index, columns_renamer=columns, level=level, errors=errors - ) - return self._create_or_update_from_compiler( - new_query_compiler=new_qc, inplace=inplace - ) - - def reindex( - self, - labels=None, - index=None, - columns=None, - axis=None, - method=None, - copy=None, - level=None, - fill_value=np.nan, - limit=None, - tolerance=None, - ): # noqa: PR01, RT01, D200 - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - axis = self._get_axis_number(axis) - if axis == 0 and labels is not None: - index = labels - elif labels is not None: - columns = labels - return super().reindex( - index=index, - columns=columns, - method=method, - copy=copy, - level=level, - fill_value=fill_value, - limit=limit, - tolerance=tolerance, - ) - - @dataframe_not_implemented() - def reindex_like( - self, - other, - method=None, - copy: bool | None = None, - limit=None, - tolerance=None, - ) -> DataFrame: # pragma: no cover - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if copy is None: - copy = True - # docs say "Same as calling .reindex(index=other.index, columns=other.columns,...).": - # https://pandas.pydata.org/pandas-docs/version/1.4/reference/api/pandas.DataFrame.reindex_like.html - return self.reindex( - index=other.index, - columns=other.columns, - method=method, - copy=copy, - limit=limit, - tolerance=tolerance, - ) - - def replace( - self, - to_replace=None, - value=no_default, - inplace: bool = False, - limit=None, - regex: bool = False, - method: str | NoDefault = no_default, - ): - """ - Replace values given in `to_replace` with `value`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - new_query_compiler = self._query_compiler.replace( - to_replace=to_replace, - value=value, - limit=limit, - regex=regex, - method=method, - ) - return self._create_or_update_from_compiler(new_query_compiler, inplace) - - def rfloordiv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `rfloordiv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rfloordiv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def radd( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get addition of ``DataFrame`` and `other`, element-wise (binary operator `radd`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "radd", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rmod( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `rmod`). - """ - return self._binary_op( - "rmod", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200 - return super().round(decimals, args=args, **kwargs) - - def rpow( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rpow", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rsub( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `rsub`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rsub", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rtruediv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `rtruediv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rtruediv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - rdiv = rtruediv - - def select_dtypes( - self, - include: ListLike | str | type | None = None, - exclude: ListLike | str | type | None = None, - ) -> DataFrame: - """ - Return a subset of the ``DataFrame``'s columns based on the column dtypes. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # This line defers argument validation to pandas, which will raise errors on our behalf in cases - # like if `include` and `exclude` are None, the same type is specified in both lists, or a string - # dtype (as opposed to object) is specified. - pandas.DataFrame().select_dtypes(include, exclude) - - if include and not is_list_like(include): - include = [include] - elif include is None: - include = [] - if exclude and not is_list_like(exclude): - exclude = [exclude] - elif exclude is None: - exclude = [] - - sel = tuple(map(set, (include, exclude))) - - # The width of the np.int_/float_ alias differs between Windows and other platforms, so - # we need to include a workaround. 
- # https://github.com/numpy/numpy/issues/9464 - # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036 - def check_sized_number_infer_dtypes(dtype): - if (isinstance(dtype, str) and dtype == "int") or (dtype is int): - return [np.int32, np.int64] - elif dtype == "float" or dtype is float: - return [np.float64, np.float32] - else: - return [infer_dtype_from_object(dtype)] - - include, exclude = map( - lambda x: set( - itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x)) - ), - sel, - ) - # We need to index on column position rather than label in case of duplicates - include_these = pandas.Series(not bool(include), index=range(len(self.columns))) - exclude_these = pandas.Series(not bool(exclude), index=range(len(self.columns))) - - def is_dtype_instance_mapper(dtype): - return functools.partial(issubclass, dtype.type) - - for i, dtype in enumerate(self.dtypes): - if include: - include_these[i] = any(map(is_dtype_instance_mapper(dtype), include)) - if exclude: - exclude_these[i] = not any( - map(is_dtype_instance_mapper(dtype), exclude) - ) - - dtype_indexer = include_these & exclude_these - indicate = [i for i, should_keep in dtype_indexer.items() if should_keep] - # We need to use iloc instead of drop in case of duplicate column names - return self.iloc[:, indicate] - - def shift( - self, - periods: int | Sequence[int] = 1, - freq=None, - axis: Axis = 0, - fill_value: Hashable = no_default, - suffix: str | None = None, - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().shift(periods, freq, axis, fill_value, suffix) - - def set_index( - self, - keys: IndexLabel - | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable], - drop: bool = True, - append: bool = False, - inplace: bool = False, - verify_integrity: bool = False, - ) -> None | DataFrame: - """ - Set the ``DataFrame`` index using existing columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - if not isinstance(keys, list): - keys = [keys] - - # make sure key is either hashable, index, or series - label_or_series = [] - - missing = [] - columns = self.columns.tolist() - for key in keys: - raise_if_native_pandas_objects(key) - if isinstance(key, pd.Series): - label_or_series.append(key._query_compiler) - elif isinstance(key, (np.ndarray, list, Iterator)): - label_or_series.append(pd.Series(key)._query_compiler) - elif isinstance(key, (pd.Index, pandas.MultiIndex)): - label_or_series += [ - s._query_compiler for s in self._to_series_list(key) - ] - else: - if not is_hashable(key): - raise TypeError( - f'The parameter "keys" may be a column key, one-dimensional array, or a list ' - f"containing only valid column keys and one-dimensional arrays. 
Received column " - f"of type {type(key)}" - ) - label_or_series.append(key) - found = key in columns - if columns.count(key) > 1: - raise ValueError(f"The column label '{key}' is not unique") - elif not found: - missing.append(key) - - if missing: - raise KeyError(f"None of {missing} are in the columns") - - new_query_compiler = self._query_compiler.set_index( - label_or_series, drop=drop, append=append - ) - - # TODO: SNOW-782633 improve this code once duplicate is supported - # this needs to pull all index which is inefficient - if verify_integrity and not new_query_compiler.index.is_unique: - duplicates = new_query_compiler.index[ - new_query_compiler.index.to_pandas().duplicated() - ].unique() - raise ValueError(f"Index has duplicate keys: {duplicates}") - - return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace) - - sparse = CachedAccessor("sparse", SparseFrameAccessor) - - def squeeze(self, axis: Axis | None = None): - """ - Squeeze 1 dimensional axis objects into scalars. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) if axis is not None else None - len_columns = self._query_compiler.get_axis_len(1) - if axis == 1 and len_columns == 1: - return Series(query_compiler=self._query_compiler) - if axis in [0, None]: - # get_axis_len(0) results in a sql query to count number of rows in current - # dataframe. We should only compute len_index if axis is 0 or None. - len_index = len(self) - if axis is None and (len_columns == 1 or len_index == 1): - return Series(query_compiler=self._query_compiler).squeeze() - if axis == 0 and len_index == 1: - return Series(query_compiler=self.T._query_compiler) - return self.copy() - - def stack( - self, - level: int | str | list = -1, - dropna: bool | NoDefault = no_default, - sort: bool | NoDefault = no_default, - future_stack: bool = False, # ignored - ): - """ - Stack the prescribed level(s) from columns to index. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if future_stack is not False: - WarningMessage.ignored_argument( # pragma: no cover - operation="DataFrame.stack", - argument="future_stack", - message="future_stack parameter has been ignored with Snowflake execution engine", - ) - if dropna is NoDefault: - dropna = True # pragma: no cover - if sort is NoDefault: - sort = True # pragma: no cover - - # This ensures that non-pandas MultiIndex objects are caught. - is_multiindex = len(self.columns.names) > 1 - if not is_multiindex or ( - is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels - ): - return self._reduce_dimension( - query_compiler=self._query_compiler.stack(level, dropna, sort) - ) - else: - return self.__constructor__( - query_compiler=self._query_compiler.stack(level, dropna, sort) - ) - - def sub( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `sub`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "sub", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - subtract = sub - - @dataframe_not_implemented() - def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Write a ``DataFrame`` to the binary Feather format. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas(pandas.DataFrame.to_feather, path, **kwargs) - - @dataframe_not_implemented() - def to_gbq( - self, - destination_table, - project_id=None, - chunksize=None, - reauth=False, - if_exists="fail", - auth_local_webserver=True, - table_schema=None, - location=None, - progress_bar=True, - credentials=None, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Write a ``DataFrame`` to a Google BigQuery table. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functionsf - return self._default_to_pandas( - pandas.DataFrame.to_gbq, - destination_table, - project_id=project_id, - chunksize=chunksize, - reauth=reauth, - if_exists=if_exists, - auth_local_webserver=auth_local_webserver, - table_schema=table_schema, - location=location, - progress_bar=progress_bar, - credentials=credentials, - ) - - @dataframe_not_implemented() - def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_orc, - path=path, - engine=engine, - index=index, - engine_kwargs=engine_kwargs, - ) - - @dataframe_not_implemented() - def to_html( - self, - buf=None, - columns=None, - col_space=None, - header=True, - index=True, - na_rep="NaN", - formatters=None, - float_format=None, - sparsify=None, - index_names=True, - justify=None, - max_rows=None, - max_cols=None, - show_dimensions=False, - decimal=".", - bold_rows=True, - classes=None, - escape=True, - notebook=False, - border=None, - table_id=None, - render_links=False, - encoding=None, - ): # noqa: PR01, RT01, D200 - """ - Render a ``DataFrame`` as an HTML table. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_html, - buf=buf, - columns=columns, - col_space=col_space, - header=header, - index=index, - na_rep=na_rep, - formatters=formatters, - float_format=float_format, - sparsify=sparsify, - index_names=index_names, - justify=justify, - max_rows=max_rows, - max_cols=max_cols, - show_dimensions=show_dimensions, - decimal=decimal, - bold_rows=bold_rows, - classes=classes, - escape=escape, - notebook=notebook, - border=border, - table_id=table_id, - render_links=render_links, - encoding=None, - ) - - @dataframe_not_implemented() - def to_parquet( - self, - path=None, - engine="auto", - compression="snappy", - index=None, - partition_cols=None, - storage_options: StorageOptions = None, - **kwargs, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - from snowflake.snowpark.modin.pandas.dispatching.factories.dispatcher import ( - FactoryDispatcher, - ) - - return FactoryDispatcher.to_parquet( - self._query_compiler, - path=path, - engine=engine, - compression=compression, - index=index, - partition_cols=partition_cols, - storage_options=storage_options, - **kwargs, - ) - - @dataframe_not_implemented() - def to_period( - self, freq=None, axis=0, copy=True - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Convert ``DataFrame`` from ``DatetimeIndex`` to ``PeriodIndex``. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().to_period(freq=freq, axis=axis, copy=copy) - - @dataframe_not_implemented() - def to_records( - self, index=True, column_dtypes=None, index_dtypes=None - ): # noqa: PR01, RT01, D200 - """ - Convert ``DataFrame`` to a NumPy record array. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_records, - index=index, - column_dtypes=column_dtypes, - index_dtypes=index_dtypes, - ) - - @dataframe_not_implemented() - def to_stata( - self, - path: FilePath | WriteBuffer[bytes], - convert_dates: dict[Hashable, str] | None = None, - write_index: bool = True, - byteorder: str | None = None, - time_stamp: datetime.datetime | None = None, - data_label: str | None = None, - variable_labels: dict[Hashable, str] | None = None, - version: int | None = 114, - convert_strl: Sequence[Hashable] | None = None, - compression: CompressionOptions = "infer", - storage_options: StorageOptions = None, - *, - value_labels: dict[Hashable, dict[float | int, str]] | None = None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_stata, - path, - convert_dates=convert_dates, - write_index=write_index, - byteorder=byteorder, - time_stamp=time_stamp, - data_label=data_label, - variable_labels=variable_labels, - version=version, - convert_strl=convert_strl, - compression=compression, - storage_options=storage_options, - value_labels=value_labels, - ) - - @dataframe_not_implemented() - def to_xml( - self, - path_or_buffer=None, - index=True, - root_name="data", - row_name="row", - na_rep=None, - attr_cols=None, - elem_cols=None, - namespaces=None, - prefix=None, - encoding="utf-8", - xml_declaration=True, - pretty_print=True, - parser="lxml", - stylesheet=None, - compression="infer", - storage_options=None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.default_to_pandas( - pandas.DataFrame.to_xml, - path_or_buffer=path_or_buffer, - index=index, - root_name=root_name, - row_name=row_name, - na_rep=na_rep, - attr_cols=attr_cols, - elem_cols=elem_cols, - namespaces=namespaces, - prefix=prefix, - encoding=encoding, - xml_declaration=xml_declaration, - pretty_print=pretty_print, - parser=parser, - stylesheet=stylesheet, - compression=compression, - storage_options=storage_options, - ) - ) - - def to_dict( - self, - orient: Literal[ - "dict", "list", "series", "split", "tight", "records", "index" - ] = "dict", - into: type[dict] = dict, - ) -> dict | list[dict]: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._to_pandas().to_dict(orient=orient, into=into) - - def to_timestamp( - self, freq=None, how="start", axis=0, copy=True - ): # noqa: PR01, RT01, D200 - """ - Cast to DatetimeIndex of timestamps, at *beginning* of period. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().to_timestamp(freq=freq, how=how, axis=axis, copy=copy) - - def truediv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `truediv`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "truediv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - div = divide = truediv - - def update( - self, other, join="left", overwrite=True, filter_func=None, errors="ignore" - ): # noqa: PR01, RT01, D200 - """ - Modify in place using non-NA values from another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not isinstance(other, DataFrame): - other = self.__constructor__(other) - query_compiler = self._query_compiler.df_update( - other._query_compiler, - join=join, - overwrite=overwrite, - filter_func=filter_func, - errors=errors, - ) - self._update_inplace(new_query_compiler=query_compiler) - - def diff( - self, - periods: int = 1, - axis: Axis = 0, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().diff( - periods=periods, - axis=axis, - ) - - def drop( - self, - labels: IndexLabel = None, - axis: Axis = 0, - index: IndexLabel = None, - columns: IndexLabel = None, - level: Level = None, - inplace: bool = False, - errors: IgnoreRaise = "raise", - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().drop( - labels=labels, - axis=axis, - index=index, - columns=columns, - level=level, - inplace=inplace, - errors=errors, - ) - - def value_counts( - self, - subset: Sequence[Hashable] | None = None, - normalize: bool = False, - sort: bool = True, - ascending: bool = False, - dropna: bool = True, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return Series( - query_compiler=self._query_compiler.value_counts( - subset=subset, - normalize=normalize, - sort=sort, - ascending=ascending, - dropna=dropna, - ), - name="proportion" if normalize else "count", - ) - - def mask( - self, - cond: DataFrame | Series | Callable | AnyArrayLike, - other: DataFrame | Series | Callable | Scalar | None = np.nan, - *, - inplace: bool = False, - axis: Axis | None = None, - level: Level | None = None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, Series) and axis is None: - raise ValueError( - "df.mask requires an axis parameter (0 or 1) when given a Series" - ) - - return super().mask( - cond, - other=other, - inplace=inplace, - axis=axis, - level=level, - ) - - def where( - self, - cond: DataFrame | Series | Callable | AnyArrayLike, - other: DataFrame | Series | Callable | Scalar | None = np.nan, - *, - inplace: bool = False, - axis: Axis | None = None, - level: Level | None = None, - ): - """ - Replace values where the condition is False. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, Series) and axis is None: - raise ValueError( - "df.where requires an axis parameter (0 or 1) when given a Series" - ) - - return super().where( - cond, - other=other, - inplace=inplace, - axis=axis, - level=level, - ) - - @dataframe_not_implemented() - def xs(self, key, axis=0, level=None, drop_level=True): # noqa: PR01, RT01, D200 - """ - Return cross-section from the ``DataFrame``. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.xs, key, axis=axis, level=level, drop_level=drop_level - ) - - def set_axis( - self, - labels: IndexLabel, - *, - axis: Axis = 0, - copy: bool | NoDefault = no_default, # ignored - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not is_scalar(axis): - raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") - return super().set_axis( - labels=labels, - # 'columns', 'rows, 'index, 0, and 1 are the only valid axis values for df. - axis=pandas.DataFrame._get_axis_name(axis), - copy=copy, - ) - - def __getattr__(self, key): - """ - Return item identified by `key`. - - Parameters - ---------- - key : hashable - Key to get. - - Returns - ------- - Any - - Notes - ----- - First try to use `__getattribute__` method. If it fails - try to get `key` from ``DataFrame`` fields. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - try: - return object.__getattribute__(self, key) - except AttributeError as err: - if key not in _ATTRS_NO_LOOKUP and key in self.columns: - return self[key] - raise err - - def __setattr__(self, key, value): - """ - Set attribute `value` identified by `key`. - - Parameters - ---------- - key : hashable - Key to set. - value : Any - Value to set. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # While we let users assign to a column labeled "x" with "df.x" , there - # are some attributes that we should assume are NOT column names and - # therefore should follow the default Python object assignment - # behavior. These are: - # - anything in self.__dict__. This includes any attributes that the - # user has added to the dataframe with, e.g., `df.c = 3`, and - # any attribute that Modin has added to the frame, e.g. - # `_query_compiler` and `_siblings` - # - `_query_compiler`, which Modin initializes before it appears in - # __dict__ - # - `_siblings`, which Modin initializes before it appears in __dict__ - # - `_cache`, which pandas.cache_readonly uses to cache properties - # before it appears in __dict__. - if key in ("_query_compiler", "_siblings", "_cache") or key in self.__dict__: - pass - elif key in self and key not in dir(self): - self.__setitem__(key, value) - # Note: return immediately so we don't keep this `key` as dataframe state. - # `__getattr__` will return the columns not present in `dir(self)`, so we do not need - # to manually track this state in the `dir`. - return - elif is_list_like(value) and key not in ["index", "columns"]: - WarningMessage.single_warning( - SET_DATAFRAME_ATTRIBUTE_WARNING - ) # pragma: no cover - object.__setattr__(self, key, value) - - def __setitem__(self, key: Any, value: Any): - """ - Set attribute `value` identified by `key`. - - Args: - key: Key to set - value: Value to set - - Note: - In the case where value is any list like or array, pandas checks the array length against the number of rows - of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw - a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if - the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use - enlargement filling with the last value in the array. 
- - Returns: - None - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - key = apply_if_callable(key, self) - if isinstance(key, DataFrame) or ( - isinstance(key, np.ndarray) and len(key.shape) == 2 - ): - # This case uses mask's codepath to perform the set, but - # we need to duplicate the code here since we are passing - # an additional kwarg `cond_fillna_with_true` to the QC here. - # We need this additional kwarg, since if df.shape - # and key.shape do not align (i.e. df has more rows), - # mask's codepath would mask the additional rows in df - # while for setitem, we need to keep the original values. - if not isinstance(key, DataFrame): - if key.dtype != bool: - raise TypeError( - "Must pass DataFrame or 2-d ndarray with boolean values only" - ) - key = DataFrame(key) - key._query_compiler._shape_hint = "array" - - if value is not None: - value = apply_if_callable(value, self) - - if isinstance(value, np.ndarray): - value = DataFrame(value) - value._query_compiler._shape_hint = "array" - elif isinstance(value, pd.Series): - # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this - # error instead, since it is more descriptive. - raise ValueError( - "setitem with a 2D key does not support Series values." - ) - - if isinstance(value, BasePandasDataset): - value = value._query_compiler - - query_compiler = self._query_compiler.mask( - cond=key._query_compiler, - other=value, - axis=None, - level=None, - cond_fillna_with_true=True, - ) - - return self._create_or_update_from_compiler(query_compiler, inplace=True) - - # Error Checking: - if (isinstance(key, pd.Series) or is_list_like(key)) and ( - isinstance(value, range) - ): - raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) - elif isinstance(value, slice): - # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. - raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) - - # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column - # key. - index, columns = slice(None), key - index_is_bool_indexer = False - if isinstance(key, slice): - if is_integer(key.start) and is_integer(key.stop): - # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as - # df.iloc[1:2, :] = val - self.iloc[key] = value - return - index, columns = key, slice(None) - elif isinstance(key, pd.Series): - if is_bool_dtype(key.dtype): - index, columns = key, slice(None) - index_is_bool_indexer = True - elif is_bool_indexer(key): - index, columns = pd.Series(key), slice(None) - index_is_bool_indexer = True - - # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case - # we have to explicitly set matching_item_columns_by_label to False for setitem. 
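As a brief sketch of the key dispatch described in the comments above (illustrative only, not part of the patch; the frame contents and the `pd` import are assumed):

    # Sketch only: how __setitem__ routes different key types.
    import modin.pandas as pd  # Snowpark plugin assumed registered

    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    # Integer slices are row keys and take the iloc path:
    df[1:2] = 0                               # same effect as df.iloc[1:2, :] = 0

    # Boolean Series (and other boolean indexers) are also row keys:
    df[pd.Series([True, False, True])] = -1   # rows 0 and 2 are set to -1

    # Any other key, e.g. a column label, is treated as a column key:
    df["a"] = [7, 8, 9]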
- index = index._query_compiler if isinstance(index, BasePandasDataset) else index - columns = ( - columns._query_compiler - if isinstance(columns, BasePandasDataset) - else columns - ) - from .indexing import is_2d_array - - matching_item_rows_by_label = not is_2d_array(value) - if is_2d_array(value): - value = DataFrame(value) - item = value._query_compiler if isinstance(value, BasePandasDataset) else value - new_qc = self._query_compiler.set_2d_labels( - index, - columns, - item, - # setitem always matches item by position - matching_item_columns_by_label=False, - matching_item_rows_by_label=matching_item_rows_by_label, - index_is_bool_indexer=index_is_bool_indexer, - # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling - # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the - # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have - # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns - # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", - # "X", "X". - deduplicate_columns=True, - ) - return self._update_inplace(new_query_compiler=new_qc) - - def abs(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().abs() - - def __and__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__and__", other, axis=1) - - def __rand__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__rand__", other, axis=1) - - def __or__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__or__", other, axis=1) - - def __ror__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__ror__", other, axis=1) - - def __neg__(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().__neg__() - - def __iter__(self): - """ - Iterate over info axis. - - Returns - ------- - iterable - Iterator of the columns names. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return iter(self.columns) - - def __contains__(self, key): - """ - Check if `key` in the ``DataFrame.columns``. - - Parameters - ---------- - key : hashable - Key to check the presence in the columns. - - Returns - ------- - bool - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.columns.__contains__(key) - - def __round__(self, decimals=0): - """ - Round each value in a ``DataFrame`` to the given number of decimals. - - Parameters - ---------- - decimals : int, default: 0 - Number of decimal places to round to. - - Returns - ------- - DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().round(decimals) - - @dataframe_not_implemented() - def __delitem__(self, key): - """ - Delete item identified by `key` label. - - Parameters - ---------- - key : hashable - Key to delete. 
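Returning to the `deduplicate_columns` comment inside `__setitem__` above: the difference from the `loc` path is easiest to see side by side. This is a sketch under assumed data, not part of the patch.

    # Sketch only: deduplicate_columns in __setitem__ versus loc.
    import modin.pandas as pd  # Snowpark plugin assumed registered

    df = pd.DataFrame({"A": [1], "B": [2]})
    item = pd.DataFrame([[10, 20]])

    df[["A", "A"]] = item           # columns stay ["A", "B"]; "A" is set from the
                                    # second provided column, i.e. becomes 20

    df2 = pd.DataFrame({"A": [1], "B": [2]})
    df2.loc[:, ["A", "A"]] = item   # columns become ["A", "A", "B"]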
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if key not in self: - raise KeyError(key) - self._update_inplace(new_query_compiler=self._query_compiler.delitem(key)) - - __add__ = add - __iadd__ = add # pragma: no cover - __radd__ = radd - __mul__ = mul - __imul__ = mul # pragma: no cover - __rmul__ = rmul - __pow__ = pow - __ipow__ = pow # pragma: no cover - __rpow__ = rpow - __sub__ = sub - __isub__ = sub # pragma: no cover - __rsub__ = rsub - __floordiv__ = floordiv - __ifloordiv__ = floordiv # pragma: no cover - __rfloordiv__ = rfloordiv - __truediv__ = truediv - __itruediv__ = truediv # pragma: no cover - __rtruediv__ = rtruediv - __mod__ = mod - __imod__ = mod # pragma: no cover - __rmod__ = rmod - __rdiv__ = rdiv - - def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): - """ - Get a Modin DataFrame that implements the dataframe exchange protocol. - - See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. - - Parameters - ---------- - nan_as_null : bool, default: False - A keyword intended for the consumer to tell the producer - to overwrite null values in the data with ``NaN`` (or ``NaT``). - This currently has no effect; once support for nullable extension - dtypes is added, this value should be propagated to columns. - allow_copy : bool, default: True - A keyword that defines whether or not the library is allowed - to make a copy of the data. For example, copying data would be necessary - if a library supports strided buffers, given that this protocol - specifies contiguous buffers. Currently, if the flag is set to ``False`` - and a copy is needed, a ``RuntimeError`` will be raised. - - Returns - ------- - ProtocolDataframe - A dataframe object following the dataframe protocol specification. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - ErrorMessage.not_implemented( - "Snowpark pandas does not support the DataFrame interchange " - + "protocol method `__dataframe__`. To use Snowpark pandas " - + "DataFrames with third-party libraries that try to call the " - + "`__dataframe__` method, please convert this Snowpark pandas " - + "DataFrame to pandas with `to_pandas()`." - ) - - return self._query_compiler.to_dataframe( - nan_as_null=nan_as_null, allow_copy=allow_copy - ) - - @dataframe_not_implemented() - @property - def attrs(self): # noqa: RT01, D200 - """ - Return dictionary of global attributes of this dataset. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def attrs(df): - return df.attrs - - return self._default_to_pandas(attrs) - - @dataframe_not_implemented() - @property - def style(self): # noqa: RT01, D200 - """ - Return a Styler object. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def style(df): - """Define __name__ attr because properties do not have it.""" - return df.style - - return self._default_to_pandas(style) - - def isin( - self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike] - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(values, dict): - return super().isin(values) - elif isinstance(values, Series): - # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. 
- # if not values.index.is_unique: - # raise ValueError("cannot compute isin with a duplicate axis.") - return self.__constructor__( - query_compiler=self._query_compiler.isin(values._query_compiler) - ) - elif isinstance(values, DataFrame): - # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. - # if not (values.columns.is_unique and values.index.is_unique): - # raise ValueError("cannot compute isin with a duplicate axis.") - return self.__constructor__( - query_compiler=self._query_compiler.isin(values._query_compiler) - ) - else: - if not is_list_like(values): - # throw pandas compatible error - raise TypeError( - "only list-like or dict-like objects are allowed " - f"to be passed to {self.__class__.__name__}.isin(), " - f"you passed a '{type(values).__name__}'" - ) - return super().isin(values) - - def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): - """ - Return or update a ``DataFrame`` with given `new_query_compiler`. - - Parameters - ---------- - new_query_compiler : PandasQueryCompiler - QueryCompiler to use to manage the data. - inplace : bool, default: False - Whether or not to perform update or creation inplace. - - Returns - ------- - DataFrame or None - None if update was done, ``DataFrame`` otherwise. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - assert ( - isinstance(new_query_compiler, type(self._query_compiler)) - or type(new_query_compiler) in self._query_compiler.__class__.__bases__ - ), f"Invalid Query Compiler object: {type(new_query_compiler)}" - if not inplace: - return self.__constructor__(query_compiler=new_query_compiler) - else: - self._update_inplace(new_query_compiler=new_query_compiler) - - def _get_numeric_data(self, axis: int): - """ - Grab only numeric data from ``DataFrame``. - - Parameters - ---------- - axis : {0, 1} - Axis to inspect on having numeric types only. - - Returns - ------- - DataFrame - ``DataFrame`` with numeric data. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas ignores `numeric_only` if `axis` is 1, but we do have to drop - # non-numeric columns if `axis` is 0. - if axis != 0: - return self - return self.drop( - columns=[ - i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) - ] - ) - - def _validate_dtypes(self, numeric_only=False): - """ - Check that all the dtypes are the same. - - Parameters - ---------- - numeric_only : bool, default: False - Whether or not to allow only numeric data. - If True and non-numeric data is found, exception - will be raised. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - dtype = self.dtypes[0] - for t in self.dtypes: - if numeric_only and not is_numeric_dtype(t): - raise TypeError(f"{t} is not a numeric data type") - elif not numeric_only and t != dtype: - raise TypeError(f"Cannot compare type '{t}' with type '{dtype}'") - - def _validate_dtypes_sum_prod_mean(self, axis, numeric_only, ignore_axis=False): - """ - Validate data dtype for `sum`, `prod` and `mean` methods. - - Parameters - ---------- - axis : {0, 1} - Axis to validate over. - numeric_only : bool - Whether or not to allow only numeric data. - If True and non-numeric data is found, exception - will be raised. - ignore_axis : bool, default: False - Whether or not to ignore `axis` parameter. 
- - Returns - ------- - DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # We cannot add datetime types, so if we are summing a column with - # dtype datetime64 and cannot ignore non-numeric types, we must throw a - # TypeError. - if ( - not axis - and numeric_only is False - and any(dtype == np.dtype("datetime64[ns]") for dtype in self.dtypes) - ): - raise TypeError("Cannot add Timestamp Types") - - # If our DataFrame has both numeric and non-numeric dtypes then - # operations between these types do not make sense and we must raise a - # TypeError. The exception to this rule is when there are datetime and - # timedelta objects, in which case we proceed with the comparison - # without ignoring any non-numeric types. We must check explicitly if - # numeric_only is False because if it is None, it will default to True - # if the operation fails with mixed dtypes. - if ( - (axis or ignore_axis) - and numeric_only is False - and np.unique([is_numeric_dtype(dtype) for dtype in self.dtypes]).size == 2 - ): - # check if there are columns with dtypes datetime or timedelta - if all( - dtype != np.dtype("datetime64[ns]") - and dtype != np.dtype("timedelta64[ns]") - for dtype in self.dtypes - ): - raise TypeError("Cannot operate on Numeric and Non-Numeric Types") - - return self._get_numeric_data(axis) if numeric_only else self - - def _to_pandas( - self, - *, - statement_params: dict[str, str] | None = None, - **kwargs: Any, - ) -> pandas.DataFrame: - """ - Convert Snowpark pandas DataFrame to pandas DataFrame - - Args: - statement_params: Dictionary of statement level parameters to be set while executing this action. - - Returns: - pandas DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.to_pandas( - statement_params=statement_params, **kwargs - ) - - def _validate_eval_query(self, expr, **kwargs): - """ - Validate the arguments of ``eval`` and ``query`` functions. - - Parameters - ---------- - expr : str - The expression to evaluate. This string cannot contain any - Python statements, only Python expressions. - **kwargs : dict - Optional arguments of ``eval`` and ``query`` functions. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(expr, str) and expr == "": - raise ValueError("expr cannot be an empty string") - - if isinstance(expr, str) and "not" in expr: - if "parser" in kwargs and kwargs["parser"] == "python": - ErrorMessage.not_implemented( # pragma: no cover - "Snowpark pandas does not yet support 'not' in the " - + "expression for the methods `DataFrame.eval` or " - + "`DataFrame.query`" - ) - - def _reduce_dimension(self, query_compiler): - """ - Reduce the dimension of data from the `query_compiler`. - - Parameters - ---------- - query_compiler : BaseQueryCompiler - Query compiler to retrieve the data. - - Returns - ------- - Series - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return Series(query_compiler=query_compiler) - - def _set_axis_name(self, name, axis=0, inplace=False): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) - renamed = self if inplace else self.copy() - if axis == 0: - renamed.index = renamed.index.set_names(name) - else: - renamed.columns = renamed.columns.set_names(name) - if not inplace: - return renamed - - def _to_datetime(self, **kwargs): - """ - Convert `self` to datetime. 
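Looking back at `_validate_dtypes_sum_prod_mean` above, its two error paths can be sketched as follows; the frames and the direct calls to the private helper are assumptions for illustration only and are not part of the patch.

    # Sketch only: the two TypeError paths in _validate_dtypes_sum_prod_mean.
    import modin.pandas as pd  # Snowpark plugin assumed registered

    dt_df = pd.DataFrame({"n": [1, 2],
                          "t": pd.to_datetime(["2024-01-01", "2024-01-02"])})
    try:
        # axis=0, numeric_only=False, and a datetime64 column present.
        dt_df._validate_dtypes_sum_prod_mean(axis=0, numeric_only=False)
    except TypeError as e:
        print(e)  # Cannot add Timestamp Types

    mixed_df = pd.DataFrame({"n": [1, 2], "s": ["x", "y"]})
    try:
        # axis=1, numeric_only=False, a mix of numeric and non-numeric columns,
        # and no datetime/timedelta columns.
        mixed_df._validate_dtypes_sum_prod_mean(axis=1, numeric_only=False)
    except TypeError as e:
        print(e)  # Cannot operate on Numeric and Non-Numeric Types

    # With numeric_only=True and axis=0 the helper instead returns
    # _get_numeric_data(0), i.e. the frame with non-numeric columns dropped.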
- - Parameters - ---------- - **kwargs : dict - Optional arguments to use during query compiler's - `to_datetime` invocation. - - Returns - ------- - Series of datetime64 dtype - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._reduce_dimension( - query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) - ) - - # Persistance support methods - BEGIN - @classmethod - def _inflate_light(cls, query_compiler): - """ - Re-creates the object from previously-serialized lightweight representation. - - The method is used for faster but not disk-storable persistence. - - Parameters - ---------- - query_compiler : BaseQueryCompiler - Query compiler to use for object re-creation. - - Returns - ------- - DataFrame - New ``DataFrame`` based on the `query_compiler`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return cls(query_compiler=query_compiler) - - @classmethod - def _inflate_full(cls, pandas_df): - """ - Re-creates the object from previously-serialized disk-storable representation. - - Parameters - ---------- - pandas_df : pandas.DataFrame - Data to use for object re-creation. - - Returns - ------- - DataFrame - New ``DataFrame`` based on the `pandas_df`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return cls(data=from_pandas(pandas_df)) - - @dataframe_not_implemented() - def __reduce__(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._query_compiler.finalize() - # if PersistentPickle.get(): - # return self._inflate_full, (self._to_pandas(),) - return self._inflate_light, (self._query_compiler,) - - # Persistance support methods - END diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py index d5d158373de..5024d0618ac 100644 --- a/src/snowflake/snowpark/modin/pandas/general.py +++ b/src/snowflake/snowpark/modin/pandas/general.py @@ -31,7 +31,7 @@ import numpy as np import pandas import pandas.core.common as common -from modin.pandas import Series +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas import IntervalIndex, NaT, Timedelta, Timestamp from pandas._libs import NaTType, lib @@ -65,7 +65,6 @@ # add this line to make doctests runnable from snowflake.snowpark.modin import pandas as pd # noqa: F401 -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import ( is_scalar, raise_if_native_pandas_objects, @@ -92,10 +91,9 @@ # linking to `snowflake.snowpark.DataFrame`, we need to explicitly # qualify return types in this file with `modin.pandas.DataFrame`. 
# SNOW-1233342: investigate how to fix these links without using absolute paths + import modin from modin.core.storage_formats import BaseQueryCompiler # pragma: no cover - import snowflake # pragma: no cover - _logger = getLogger(__name__) VALID_DATE_TYPE = Union[ @@ -137,8 +135,8 @@ def notna(obj): # noqa: PR01, RT01, D200 @snowpark_pandas_telemetry_standalone_function_decorator def merge( - left: snowflake.snowpark.modin.pandas.DataFrame | Series, - right: snowflake.snowpark.modin.pandas.DataFrame | Series, + left: modin.pandas.DataFrame | Series, + right: modin.pandas.DataFrame | Series, how: str | None = "inner", on: IndexLabel | None = None, left_on: None @@ -414,7 +412,7 @@ def merge_asof( tolerance: int | Timedelta | None = None, allow_exact_matches: bool = True, direction: str = "backward", -) -> snowflake.snowpark.modin.pandas.DataFrame: +) -> modin.pandas.DataFrame: """ Perform a merge by key distance. @@ -1047,7 +1045,7 @@ def unique(values) -> np.ndarray: >>> pd.unique([pd.Timestamp('2016-01-01', tz='US/Eastern') ... for _ in range(3)]) - array([Timestamp('2015-12-31 21:00:00-0800', tz='America/Los_Angeles')], + array([Timestamp('2016-01-01 00:00:00-0500', tz='UTC-05:00')], dtype=object) >>> pd.unique([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")]) @@ -1105,8 +1103,8 @@ def value_counts( @snowpark_pandas_telemetry_standalone_function_decorator def concat( objs: ( - Iterable[snowflake.snowpark.modin.pandas.DataFrame | Series] - | Mapping[Hashable, snowflake.snowpark.modin.pandas.DataFrame | Series] + Iterable[modin.pandas.DataFrame | Series] + | Mapping[Hashable, modin.pandas.DataFrame | Series] ), axis: Axis = 0, join: str = "outer", @@ -1117,7 +1115,7 @@ def concat( verify_integrity: bool = False, sort: bool = False, copy: bool = True, -) -> snowflake.snowpark.modin.pandas.DataFrame | Series: +) -> modin.pandas.DataFrame | Series: """ Concatenate pandas objects along a particular axis. 
@@ -1490,7 +1488,7 @@ def concat( def to_datetime( arg: DatetimeScalarOrArrayConvertible | DictConvertible - | snowflake.snowpark.modin.pandas.DataFrame + | modin.pandas.DataFrame | Series, errors: DateTimeErrorChoices = "raise", dayfirst: bool = False, @@ -1750,35 +1748,35 @@ def to_datetime( DatetimeIndex(['2018-10-26 12:00:00', '2018-10-26 13:00:15'], dtype='datetime64[ns]', freq=None) >>> pd.to_datetime(['2018-10-26 12:00:00 -0500', '2018-10-26 13:00:00 -0500']) - DatetimeIndex(['2018-10-26 10:00:00-07:00', '2018-10-26 11:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2018-10-26 12:00:00-05:00', '2018-10-26 13:00:00-05:00'], dtype='datetime64[ns, UTC-05:00]', freq=None) - Use right format to convert to timezone-aware type (Note that when call Snowpark pandas API to_pandas() the timezone-aware output will always be converted to session timezone): >>> pd.to_datetime(['2018-10-26 12:00:00 -0500', '2018-10-26 13:00:00 -0500'], format="%Y-%m-%d %H:%M:%S %z") - DatetimeIndex(['2018-10-26 10:00:00-07:00', '2018-10-26 11:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2018-10-26 12:00:00-05:00', '2018-10-26 13:00:00-05:00'], dtype='datetime64[ns, UTC-05:00]', freq=None) - Timezone-aware inputs *with mixed time offsets* (for example issued from a timezone with daylight savings, such as Europe/Paris): >>> pd.to_datetime(['2020-10-25 02:00:00 +0200', '2020-10-25 04:00:00 +0100']) - DatetimeIndex(['2020-10-24 17:00:00-07:00', '2020-10-24 20:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex([2020-10-25 02:00:00+02:00, 2020-10-25 04:00:00+01:00], dtype='object', freq=None) >>> pd.to_datetime(['2020-10-25 02:00:00 +0200', '2020-10-25 04:00:00 +0100'], format="%Y-%m-%d %H:%M:%S %z") - DatetimeIndex(['2020-10-24 17:00:00-07:00', '2020-10-24 20:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex([2020-10-25 02:00:00+02:00, 2020-10-25 04:00:00+01:00], dtype='object', freq=None) Setting ``utc=True`` makes sure always convert to timezone-aware outputs: - Timezone-naive inputs are *localized* based on the session timezone >>> pd.to_datetime(['2018-10-26 12:00', '2018-10-26 13:00'], utc=True) - DatetimeIndex(['2018-10-26 05:00:00-07:00', '2018-10-26 06:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2018-10-26 12:00:00+00:00', '2018-10-26 13:00:00+00:00'], dtype='datetime64[ns, UTC]', freq=None) - Timezone-aware inputs are *converted* to session timezone >>> pd.to_datetime(['2018-10-26 12:00:00 -0530', '2018-10-26 12:00:00 -0500'], ... 
utc=True) - DatetimeIndex(['2018-10-26 10:30:00-07:00', '2018-10-26 10:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2018-10-26 17:30:00+00:00', '2018-10-26 17:00:00+00:00'], dtype='datetime64[ns, UTC]', freq=None) """ # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py raise_if_native_pandas_objects(arg) diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py index c672f04da63..5da10d9b7a6 100644 --- a/src/snowflake/snowpark/modin/pandas/indexing.py +++ b/src/snowflake/snowpark/modin/pandas/indexing.py @@ -45,6 +45,7 @@ import pandas from modin.pandas import Series from modin.pandas.base import BasePandasDataset +from modin.pandas.dataframe import DataFrame from pandas._libs.tslibs import Resolution, parsing from pandas._typing import AnyArrayLike, Scalar from pandas.api.types import is_bool, is_list_like @@ -61,7 +62,6 @@ import snowflake.snowpark.modin.pandas as pd import snowflake.snowpark.modin.pandas.utils as frontend_utils -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import is_scalar from snowflake.snowpark.modin.plugin._internal.indexing_utils import ( MULTIPLE_ELLIPSIS_INDEXING_ERROR_MESSAGE, diff --git a/src/snowflake/snowpark/modin/pandas/io.py b/src/snowflake/snowpark/modin/pandas/io.py index 25959212a18..b92e8ee3582 100644 --- a/src/snowflake/snowpark/modin/pandas/io.py +++ b/src/snowflake/snowpark/modin/pandas/io.py @@ -92,7 +92,7 @@ # below logic is to handle circular imports without errors if TYPE_CHECKING: # pragma: no cover - from .dataframe import DataFrame + from modin.pandas.dataframe import DataFrame # TODO: SNOW-1265551: add inherit_docstrings decorators once docstring overrides are available @@ -106,7 +106,7 @@ class ModinObjects: def DataFrame(cls): """Get ``modin.pandas.DataFrame`` class.""" if cls._dataframe is None: - from .dataframe import DataFrame + from modin.pandas.dataframe import DataFrame cls._dataframe = DataFrame return cls._dataframe diff --git a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py index 3529355b81b..ee782f3cdf3 100644 --- a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py +++ b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py @@ -5,10 +5,9 @@ from collections.abc import Iterator from typing import Any, Callable +import modin.pandas.dataframe as DataFrame import pandas -import snowflake.snowpark.modin.pandas.dataframe as DataFrame - PARTITION_SIZE = 4096 diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py index 3986e3d52a9..a48f16992d4 100644 --- a/src/snowflake/snowpark/modin/pandas/utils.py +++ b/src/snowflake/snowpark/modin/pandas/utils.py @@ -78,7 +78,7 @@ def from_non_pandas(df, index, columns, dtype): new_qc = FactoryDispatcher.from_non_pandas(df, index, columns, dtype) if new_qc is not None: - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame return DataFrame(query_compiler=new_qc) return new_qc @@ -99,7 +99,7 @@ def from_pandas(df): A new Modin DataFrame object. 
""" # from modin.core.execution.dispatching.factories.dispatcher import FactoryDispatcher - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_pandas(df)) @@ -118,10 +118,11 @@ def from_arrow(at): DataFrame A new Modin DataFrame object. """ + from modin.pandas import DataFrame + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( FactoryDispatcher, ) - from snowflake.snowpark.modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_arrow(at)) @@ -142,10 +143,11 @@ def from_dataframe(df): DataFrame A new Modin DataFrame object. """ + from modin.pandas import DataFrame + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( FactoryDispatcher, ) - from snowflake.snowpark.modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_dataframe(df)) @@ -226,7 +228,7 @@ def from_modin_frame_to_mi(df, sortorder=None, names=None): pandas.MultiIndex The pandas.MultiIndex representation of the given DataFrame. """ - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame if isinstance(df, DataFrame): df = df._to_pandas() diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py index d3ac525572a..eceb9ca7d7f 100644 --- a/src/snowflake/snowpark/modin/plugin/__init__.py +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -69,6 +69,7 @@ inherit_modules = [ (docstrings.base.BasePandasDataset, modin.pandas.base.BasePandasDataset), + (docstrings.dataframe.DataFrame, modin.pandas.dataframe.DataFrame), (docstrings.series.Series, modin.pandas.series.Series), (docstrings.series_utils.StringMethods, modin.pandas.series_utils.StringMethods), ( @@ -90,17 +91,3 @@ snowflake.snowpark._internal.utils.should_warn_dynamic_pivot_is_in_private_preview = ( False ) - - -# TODO: SNOW-1504302: Modin upgrade - use Snowpark pandas DataFrame for isocalendar -# OSS Modin's DatetimeProperties frontend class wraps the returned query compiler with `modin.pandas.DataFrame`. -# Since we currently replace `pd.DataFrame` with our own Snowpark pandas DataFrame object, this causes errors -# since OSS Modin explicitly imports its own DataFrame class here. This override can be removed once the frontend -# DataFrame class is removed from our codebase. 
-def isocalendar(self): # type: ignore - from snowflake.snowpark.modin.pandas import DataFrame - - return DataFrame(query_compiler=self._query_compiler.dt_isocalendar()) - - -modin.pandas.series_utils.DatetimeProperties.isocalendar = isocalendar diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 01ccad8f430..0005df924db 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -9,7 +9,7 @@ from collections.abc import Hashable, Iterable from functools import partial from inspect import getmembers -from types import BuiltinFunctionType +from types import BuiltinFunctionType, MappingProxyType from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Union import numpy as np @@ -56,6 +56,7 @@ stddev, stddev_pop, sum as sum_, + trunc, var_pop, variance, when, @@ -65,6 +66,9 @@ OrderedDataFrame, OrderingColumn, ) +from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + TimedeltaType, +) from snowflake.snowpark.modin.plugin._internal.utils import ( from_pandas_label, pandas_lit, @@ -85,7 +89,7 @@ } -def array_agg_keepna( +def _array_agg_keepna( column_to_aggregate: ColumnOrName, ordering_columns: Iterable[OrderingColumn] ) -> Column: """ @@ -239,62 +243,63 @@ def _columns_coalescing_idxmax_idxmin_helper( ) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding snowflake builtin aggregation function for axis=0. If any change -# is made to this map, ensure GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE and -# GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES are updated accordingly. -SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": count, - "mean": mean, - "min": min_, - "max": max_, - "idxmax": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmax" - ), - "idxmin": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmin" - ), - "sum": sum_, - "median": median, - "skew": skew, - "std": stddev, - "var": variance, - "all": builtin("booland_agg"), - "any": builtin("boolor_agg"), - np.max: max_, - np.min: min_, - np.sum: sum_, - np.mean: mean, - np.median: median, - np.std: stddev, - np.var: variance, - "array_agg": array_agg, - "quantile": column_quantile, - "nunique": count_distinct, -} -GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = ( - "min", - "max", - "sum", - "mean", - "median", - "std", - np.max, - np.min, - np.sum, - np.mean, - np.median, - np.std, -) -GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = ( - "any", - "all", - "count", - "idxmax", - "idxmin", - "size", - "nunique", -) +class _SnowparkPandasAggregation(NamedTuple): + """ + A representation of a Snowpark pandas aggregation. + + This structure gives us a common representation for an aggregation that may + have multiple aliases, like "sum" and np.sum. + """ + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. + # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. 
+ preserves_snowpark_pandas_types: bool + + # This callable takes a single Snowpark column as input and aggregates the + # column on axis=0. If None, Snowpark pandas does not support this + # aggregation on axis=0. + axis_0_aggregation: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # the columns on axis=1 with skipna=True, i.e. not including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=True. + axis_1_aggregation_skipna: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # the columns on axis=1 with skipna=False, i.e. including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=False. + axis_1_aggregation_keepna: Optional[Callable] = None + + +class SnowflakeAggFunc(NamedTuple): + """ + A Snowflake aggregation, including information about how the aggregation acts on SnowparkPandasType. + """ + + # The aggregation function in Snowpark. + # For aggregation on axis=0, this field should take a single Snowpark + # column and return the aggregated column. + # For aggregation on axis=1, this field should take an arbitrary number + # of Snowpark columns and return the aggregated column. + snowpark_aggregation: Callable + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. + # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool class AggFuncWithLabel(NamedTuple): @@ -413,35 +418,143 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable: return sum(builtin("zeroifnull")(col) for col in cols) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding aggregation function for axis=1 when skipna=True. The returned aggregation -# function may either be a builtin aggregation function, or a function taking in *arg columns -# that then calls the appropriate builtin aggregations. -SNOWFLAKE_COLUMNS_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": _columns_count, - "sum": _columns_coalescing_sum, - np.sum: _columns_coalescing_sum, - "min": _columns_coalescing_min, - "max": _columns_coalescing_max, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - np.min: _columns_coalescing_min, - np.max: _columns_coalescing_max, -} +def _create_pandas_to_snowpark_pandas_aggregation_map( + pandas_functions: Iterable[AggFuncTypeBase], + snowpark_pandas_aggregation: _SnowparkPandasAggregation, +) -> MappingProxyType[AggFuncTypeBase, _SnowparkPandasAggregation]: + """ + Create a map from the given pandas functions to the given _SnowparkPandasAggregation. 
-# These functions are called instead if skipna=False -SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "min": least, - "max": greatest, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark - # sum_, since Snowpark sum_ gets the sum of all rows within a single column. - "sum": lambda *cols: sum(cols), - np.sum: lambda *cols: sum(cols), - np.min: least, - np.max: greatest, -} + Args; + pandas_functions: The pandas functions that map to the given aggregation. + snowpark_pandas_aggregation: The aggregation to map to + + Returns: + The map. + """ + return MappingProxyType({k: snowpark_pandas_aggregation for k in pandas_functions}) + + +# Map between the pandas input aggregation function (str or numpy function) and +# _SnowparkPandasAggregation representing information about applying the +# aggregation in Snowpark pandas. +_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[ + AggFuncTypeBase, _SnowparkPandasAggregation +] = MappingProxyType( + { + "count": _SnowparkPandasAggregation( + axis_0_aggregation=count, + axis_1_aggregation_skipna=_columns_count, + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("mean", np.mean), + _SnowparkPandasAggregation( + axis_0_aggregation=mean, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("min", np.min), + _SnowparkPandasAggregation( + axis_0_aggregation=min_, + axis_1_aggregation_keepna=least, + axis_1_aggregation_skipna=_columns_coalescing_min, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("max", np.max), + _SnowparkPandasAggregation( + axis_0_aggregation=max_, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=_columns_coalescing_max, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("sum", np.sum), + _SnowparkPandasAggregation( + axis_0_aggregation=sum_, + # IMPORTANT: count and sum use python builtin sum to invoke + # __add__ on each column rather than Snowpark sum_, since + # Snowpark sum_ gets the sum of all rows within a single column. 
+ axis_1_aggregation_keepna=lambda *cols: sum(cols), + axis_1_aggregation_skipna=_columns_coalescing_sum, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("median", np.median), + _SnowparkPandasAggregation( + axis_0_aggregation=median, + preserves_snowpark_pandas_types=True, + ), + ), + "idxmax": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmax" + ), + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "idxmin": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmin" + ), + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "skew": _SnowparkPandasAggregation( + axis_0_aggregation=skew, + preserves_snowpark_pandas_types=True, + ), + "all": _SnowparkPandasAggregation( + # all() for a column with no non-null values is NULL in Snowflake, but True in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("booland_agg")(col(c)), pandas_lit(True) + ), + preserves_snowpark_pandas_types=False, + ), + "any": _SnowparkPandasAggregation( + # any() for a column with no non-null values is NULL in Snowflake, but False in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("boolor_agg")(col(c)), pandas_lit(False) + ), + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("std", np.std), + _SnowparkPandasAggregation( + axis_0_aggregation=stddev, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("var", np.var), + _SnowparkPandasAggregation( + axis_0_aggregation=variance, + # variance units are the square of the input column units, so + # variance does not preserve types. + preserves_snowpark_pandas_types=False, + ), + ), + "array_agg": _SnowparkPandasAggregation( + axis_0_aggregation=array_agg, + preserves_snowpark_pandas_types=False, + ), + "quantile": _SnowparkPandasAggregation( + axis_0_aggregation=column_quantile, + preserves_snowpark_pandas_types=True, + ), + "nunique": _SnowparkPandasAggregation( + axis_0_aggregation=count_distinct, + preserves_snowpark_pandas_types=False, + ), + } +) class AggregateColumnOpParameters(NamedTuple): @@ -462,7 +575,7 @@ class AggregateColumnOpParameters(NamedTuple): agg_snowflake_quoted_identifier: str # the snowflake aggregation function to apply on the column - snowflake_agg_func: Callable + snowflake_agg_func: SnowflakeAggFunc # the columns specifying the order of rows in the column. This is only # relevant for aggregations that depend on row order, e.g. summing a string @@ -471,88 +584,108 @@ class AggregateColumnOpParameters(NamedTuple): def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool: - return agg_func in SNOWFLAKE_BUILTIN_AGG_FUNC_MAP + return agg_func in _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION def get_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int = 0 -) -> Optional[Callable]: + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] +) -> Optional[SnowflakeAggFunc]: """ Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function. 
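# A small illustration (native pandas only) of why the "all"/"any" entries above
# wrap the Snowflake aggregate in COALESCE: with no non-null values, pandas
# returns True for all() and False for any(), while BOOLAND_AGG/BOOLOR_AGG
# return NULL.
import pandas as native_pd

no_values = native_pd.Series([None, None], dtype="float64")
assert no_values.all() == True   # hence COALESCE(BOOLAND_AGG(...), TRUE)
assert no_values.any() == False  # hence COALESCE(BOOLOR_AGG(...), FALSE)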
If no corresponding snowflake aggregation function can be found, return None. """ - if axis == 0: - snowflake_agg_func = SNOWFLAKE_BUILTIN_AGG_FUNC_MAP.get(agg_func) - if snowflake_agg_func == stddev or snowflake_agg_func == variance: - # for aggregation function std and var, we only support ddof = 0 or ddof = 1. - # when ddof is 1, std is mapped to stddev, var is mapped to variance - # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop - # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 - ddof = agg_kwargs.get("ddof", 1) - if ddof != 1 and ddof != 0: - return None - if ddof == 0: - return stddev_pop if snowflake_agg_func == stddev else var_pop - elif snowflake_agg_func == column_quantile: - interpolation = agg_kwargs.get("interpolation", "linear") - q = agg_kwargs.get("q", 0.5) - if interpolation not in ("linear", "nearest"): - return None - if not is_scalar(q): - # SNOW-1062878 Because list-like q would return multiple rows, calling quantile - # through the aggregate frontend in this manner is unsupported. - return None - return lambda col: column_quantile(col, interpolation, q) - elif agg_func in ("all", "any"): - # If there are no rows in the input frame, the function will also return NULL, which should - # instead by TRUE for "all" and FALSE for "any". - # Need to wrap column name in IDENTIFIER, or else the agg function will treat the name - # as a string literal. - # The generated SQL expression for "all" is - # IFNULL(BOOLAND_AGG(IDENTIFIER("column_name")), TRUE) - # The expression for "any" is - # IFNULL(BOOLOR_AGG(IDENTIFIER("column_name")), FALSE) - default_value = bool(agg_func == "all") - return lambda col: builtin("ifnull")( - # mypy refuses to acknowledge snowflake_agg_func is non-NULL here - snowflake_agg_func(builtin("identifier")(col)), # type: ignore[misc] - pandas_lit(default_value), + if axis == 1: + return _generate_rowwise_aggregation_function(agg_func, agg_kwargs) + + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + + if snowpark_pandas_aggregation is None: + # We don't have any implementation at all for this aggregation. + return None + + snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation + + if snowpark_aggregation is None: + # We don't have an implementation on axis=0 for this aggregation. + return None + + # Rewrite some aggregations according to `agg_kwargs.` + if snowpark_aggregation == stddev or snowpark_aggregation == variance: + # for aggregation function std and var, we only support ddof = 0 or ddof = 1. + # when ddof is 1, std is mapped to stddev, var is mapped to variance + # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop + # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 + ddof = agg_kwargs.get("ddof", 1) + if ddof != 1 and ddof != 0: + return None + if ddof == 0: + snowpark_aggregation = ( + stddev_pop if snowpark_aggregation == stddev else var_pop ) - else: - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) + elif snowpark_aggregation == column_quantile: + interpolation = agg_kwargs.get("interpolation", "linear") + q = agg_kwargs.get("q", 0.5) + if interpolation not in ("linear", "nearest"): + return None + if not is_scalar(q): + # SNOW-1062878 Because list-like q would return multiple rows, calling quantile + # through the aggregate frontend in this manner is unsupported. 
+ return None + + def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: + return column_quantile(col, interpolation, q) - return snowflake_agg_func + assert ( + snowpark_aggregation is not None + ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation." + return SnowflakeAggFunc( + snowpark_aggregation=snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def generate_rowwise_aggregation_function( +def _generate_rowwise_aggregation_function( agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any] -) -> Optional[Callable]: +) -> Optional[SnowflakeAggFunc]: """ Get a callable taking *arg columns to apply for an aggregation. Unlike get_snowflake_agg_func, this function may return a wrapped composition of Snowflake builtin functions depending on the values of the specified kwargs. """ - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) - if not agg_kwargs.get("skipna", True): - snowflake_agg_func = SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP.get( - agg_func, snowflake_agg_func - ) + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + if snowpark_pandas_aggregation is None: + return None + snowpark_aggregation = ( + snowpark_pandas_aggregation.axis_1_aggregation_skipna + if agg_kwargs.get("skipna", True) + else snowpark_pandas_aggregation.axis_1_aggregation_keepna + ) + if snowpark_aggregation is None: + return None min_count = agg_kwargs.get("min_count", 0) if min_count > 0: + original_aggregation = snowpark_aggregation + # Create a case statement to check if the number of non-null values exceeds min_count # when min_count > 0, if the number of not NULL values is < min_count, return NULL. - def agg_func_wrapper(fn: Callable) -> Callable: - return lambda *cols: when( - _columns_count(*cols) < min_count, pandas_lit(None) - ).otherwise(fn(*cols)) + def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn: + return when(_columns_count(*cols) < min_count, pandas_lit(None)).otherwise( + original_aggregation(*cols) + ) - return snowflake_agg_func and agg_func_wrapper(snowflake_agg_func) - return snowflake_agg_func + return SnowflakeAggFunc( + snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def is_supported_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int +def _is_supported_snowflake_agg_func( + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ check if the aggregation function is supported with snowflake. Current supported @@ -566,12 +699,14 @@ def is_supported_snowflake_agg_func( is_valid: bool. Whether it is valid to implement with snowflake or not. """ if isinstance(agg_func, tuple) and len(agg_func) == 2: + # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`, + # take the second part of the named aggregation. agg_func = agg_func[0] return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None -def are_all_agg_funcs_supported_by_snowflake( - agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: int +def _are_all_agg_funcs_supported_by_snowflake( + agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ Check if all aggregation functions in the given list are snowflake supported @@ -582,14 +717,14 @@ def are_all_agg_funcs_supported_by_snowflake( return False. 
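# A small illustration (native pandas only) of the min_count semantics that the
# CASE-style wrapper above reproduces for axis=1 aggregations: the result
# becomes null once the number of non-null inputs falls below min_count.
import pandas as native_pd

frame = native_pd.DataFrame({"a": [1.0], "b": [None]})
assert frame.sum(axis=1, min_count=1).iloc[0] == 1.0
assert native_pd.isna(frame.sum(axis=1, min_count=2).iloc[0])  # only 1 non-null value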
""" return all( - is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs + _is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs ) def check_is_aggregation_supported_in_snowflake( agg_func: AggFuncType, agg_kwargs: dict[str, Any], - axis: int, + axis: Literal[0, 1], ) -> bool: """ check if distributed implementation with snowflake is available for the aggregation @@ -608,18 +743,18 @@ def check_is_aggregation_supported_in_snowflake( if is_dict_like(agg_func): return all( ( - are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) + _are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) if is_list_like(value) and not is_named_tuple(value) - else is_supported_snowflake_agg_func(value, agg_kwargs, axis) + else _is_supported_snowflake_agg_func(value, agg_kwargs, axis) ) for value in agg_func.values() ) elif is_list_like(agg_func): - return are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) - return is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) + return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) + return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) -def is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: +def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: """ Is the given snowflake aggregation function needs to be applied on the numeric column. """ @@ -697,7 +832,7 @@ def drop_non_numeric_data_columns( ) -def generate_aggregation_column( +def _generate_aggregation_column( agg_column_op_params: AggregateColumnOpParameters, agg_kwargs: dict[str, Any], is_groupby_agg: bool, @@ -721,8 +856,14 @@ def generate_aggregation_column( SnowparkColumn after the aggregation function. The column is also aliased back to the original name """ snowpark_column = agg_column_op_params.snowflake_quoted_identifier - snowflake_agg_func = agg_column_op_params.snowflake_agg_func - if is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( + snowflake_agg_func = agg_column_op_params.snowflake_agg_func.snowpark_aggregation + + if snowflake_agg_func in (variance, var_pop) and isinstance( + agg_column_op_params.data_type, TimedeltaType + ): + raise TypeError("timedelta64 type does not support var operations") + + if _is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( agg_column_op_params.data_type, BooleanType ): # if the column is a boolean column and the aggregation function requires numeric values, @@ -753,7 +894,7 @@ def generate_aggregation_column( # note that we always assume keepna for array_agg. TODO(SNOW-1040398): # make keepna treatment consistent across array_agg and other # aggregation methods. - agg_snowpark_column = array_agg_keepna( + agg_snowpark_column = _array_agg_keepna( snowpark_column, ordering_columns=agg_column_op_params.ordering_columns ) elif ( @@ -825,6 +966,19 @@ def generate_aggregation_column( ), f"No case expression is constructed with skipna({skipna}), min_count({min_count})" agg_snowpark_column = case_expr.otherwise(agg_snowpark_column) + if ( + isinstance(agg_column_op_params.data_type, TimedeltaType) + and agg_column_op_params.snowflake_agg_func.preserves_snowpark_pandas_types + ): + # timedelta aggregations that produce timedelta results might produce + # a decimal type in snowflake, e.g. + # pd.Series([pd.Timestamp(1), pd.Timestamp(2)]).mean() produces 1.5 in + # Snowflake. We truncate the decimal part of the result, as pandas + # does. 
+ agg_snowpark_column = cast( + trunc(agg_snowpark_column), agg_column_op_params.data_type.snowpark_type + ) + # rename the column to agg_column_quoted_identifier agg_snowpark_column = agg_snowpark_column.as_( agg_column_op_params.agg_snowflake_quoted_identifier @@ -857,7 +1011,7 @@ def aggregate_with_ordered_dataframe( is_groupby_agg = groupby_columns is not None agg_list: list[SnowparkColumn] = [ - generate_aggregation_column( + _generate_aggregation_column( agg_column_op_params=agg_col_op, agg_kwargs=agg_kwargs, is_groupby_agg=is_groupby_agg, @@ -973,7 +1127,7 @@ def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str: ) -def generate_pandas_labels_for_agg_result_columns( +def _generate_pandas_labels_for_agg_result_columns( pandas_label: Hashable, num_levels: int, agg_func_list: list[AggFuncInfo], @@ -1102,7 +1256,7 @@ def generate_column_agg_info( ) # generate the pandas label and quoted identifier for the result aggregation columns, one # for each aggregation function to apply. - agg_col_labels = generate_pandas_labels_for_agg_result_columns( + agg_col_labels = _generate_pandas_labels_for_agg_result_columns( pandas_label_to_identifier.pandas_label, num_levels, agg_func_list, # type: ignore[arg-type] diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index b58ba4f50ea..f87cdcd2e47 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -81,7 +81,7 @@ class GroupbyApplySortMethod(Enum): def check_return_variant_and_get_return_type(func: Callable) -> tuple[bool, DataType]: """Check whether the function returns a variant in Snowflake, and get its return type.""" - return_type, _ = get_types_from_type_hints(func, TempObjectType.FUNCTION) + return_type = deduce_return_type_from_function(func) if return_type is None or isinstance( return_type, (VariantType, PandasSeriesType, PandasDataFrameType) ): @@ -390,6 +390,7 @@ def create_udtf_for_groupby_apply( series_groupby: bool, by_types: list[DataType], existing_identifiers: list[str], + force_list_like_to_series: bool = False, ) -> UserDefinedTableFunction: """ Create a UDTF from the Python function for groupby.apply. @@ -480,6 +481,7 @@ def create_udtf_for_groupby_apply( series_groupby: Whether we are performing a SeriesGroupBy.apply() instead of DataFrameGroupBy.apply() by_types: The snowflake types of the by columns. existing_identifiers: List of existing column identifiers; these are omitted when creating new column identifiers. 
+ force_list_like_to_series: Force the function result to series if it is list-like Returns ------- @@ -553,6 +555,17 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def # https://github.com/snowflakedb/snowpandas/pull/823/files#r1507286892 input_object = input_object.infer_objects() func_result = func(input_object, *args, **kwargs) + if ( + force_list_like_to_series + and not isinstance(func_result, native_pd.Series) + and native_pd.api.types.is_list_like(func_result) + ): + if len(func_result) == 1: + func_result = func_result[0] + else: + func_result = native_pd.Series(func_result) + if len(func_result) == len(df.index): + func_result.index = df.index if isinstance(func_result, native_pd.Series): if series_groupby: func_result_as_frame = func_result.to_frame() @@ -754,7 +767,7 @@ def __init__(self) -> None: def convert_numpy_int_result_to_int(value: Any) -> Any: """ - If the result is a numpy int, convert it to a python int. + If the result is a numpy int (or bool), convert it to a python int (or bool.) Use this function to make UDF results JSON-serializable. numpy ints are not JSON-serializable, but python ints are. Note that this function cannot make @@ -772,9 +785,14 @@ def convert_numpy_int_result_to_int(value: Any) -> Any: Returns ------- - int(value) if the value is a numpy int, otherwise the value. + int(value) if the value is a numpy int, + bool(value) if the value is a numpy bool, otherwise the value. """ - return int(value) if np.issubdtype(type(value), np.integer) else value + return ( + int(value) + if np.issubdtype(type(value), np.integer) + else (bool(value) if np.issubdtype(type(value), np.bool_) else value) + ) def deduce_return_type_from_function( @@ -887,7 +905,7 @@ def get_metadata_from_groupby_apply_pivot_result_column_names( input: get_metadata_from_groupby_apply_pivot_result_column_names([ - # this representa a data column named ('a', 'group_key') at position 0 + # this represents a data column named ('a', 'group_key') at position 0 '"\'{""0"": ""a"", ""1"": ""group_key"", ""data_pos"": 0, ""names"": [""c1"", ""c2""]}\'"', # this represents a data column named ('b', 'int_col') at position 1 '"\'{""0"": ""b"", ""1"": ""int_col"", ""data_pos"": 1, ""names"": [""c1"", ""c2""]}\'"', @@ -1110,7 +1128,9 @@ def groupby_apply_pivot_result_to_final_ordered_dataframe( # in GROUP_KEY_APPEARANCE_ORDER) and assign the # label i to all rows that came from func(group_i). 
[ - original_row_position_snowflake_quoted_identifier + col(original_row_position_snowflake_quoted_identifier).as_( + new_index_identifier + ) if sort_method is GroupbyApplySortMethod.ORIGINAL_ROW_ORDER else ( dense_rank().over( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py index 1aa81b36e64..475fbfcefa7 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py @@ -185,19 +185,6 @@ def compute_power_between_snowpark_columns( return result -def is_binary_op_supported(op: str) -> bool: - """ - check whether binary operation is mappable to Snowflake - Args - op: op as string - - Returns: - True if binary operation can be mapped to Snowflake/Snowpark, else False - """ - - return op in SUPPORTED_BINARY_OPERATIONS - - def _compute_subtraction_between_snowpark_timestamp_columns( first_operand: SnowparkColumn, first_datatype: DataType, @@ -312,314 +299,527 @@ def _op_is_between_timedelta_and_numeric( ) -def compute_binary_op_between_snowpark_columns( - op: str, - first_operand: SnowparkColumn, - first_datatype: DataTypeGetter, - second_operand: SnowparkColumn, - second_datatype: DataTypeGetter, -) -> SnowparkPandasColumn: - """ - Compute pandas binary operation for two SnowparkColumns - Args: - op: pandas operation - first_operand: SnowparkColumn for lhs - first_datatype: Callable for Snowpark Datatype for lhs - second_operand: SnowparkColumn for rhs - second_datatype: Callable for Snowpark DateType for rhs - it is not needed. +class BinaryOp: + def __init__( + self, + op: str, + first_operand: SnowparkColumn, + first_datatype: DataTypeGetter, + second_operand: SnowparkColumn, + second_datatype: DataTypeGetter, + ) -> None: + """ + Construct a BinaryOp object to compute pandas binary operation for two SnowparkColumns + Args: + op: pandas operation + first_operand: SnowparkColumn for lhs + first_datatype: Callable for Snowpark Datatype for lhs + second_operand: SnowparkColumn for rhs + second_datatype: Callable for Snowpark DateType for rhs + it is not needed. + """ + self.op = op + self.first_operand = first_operand + self.first_datatype = first_datatype + self.second_operand = second_operand + self.second_datatype = second_datatype + self.result_column = None + self.result_snowpark_pandas_type = None + + @staticmethod + def is_binary_op_supported(op: str) -> bool: + """ + check whether binary operation is mappable to Snowflake + Args + op: op as string + + Returns: + True if binary operation can be mapped to Snowflake/Snowpark, else False + """ + + return op in SUPPORTED_BINARY_OPERATIONS + + @staticmethod + def create( + op: str, + first_operand: SnowparkColumn, + first_datatype: DataTypeGetter, + second_operand: SnowparkColumn, + second_datatype: DataTypeGetter, + ) -> "BinaryOp": + """ + Create a BinaryOp object to compute pandas binary operation for two SnowparkColumns + Args: + op: pandas operation + first_operand: SnowparkColumn for lhs + first_datatype: Callable for Snowpark Datatype for lhs + second_operand: SnowparkColumn for rhs + second_datatype: Callable for Snowpark DateType for rhs + it is not needed. 
+ """ + + def snake_to_camel(snake_str: str) -> str: + """Converts a snake case string to camel case.""" + components = snake_str.split("_") + return "".join(x.title() for x in components) + + if op in _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP: + # Normalize right-sided binary operations to the equivalent left-sided + # operations with swapped operands. For example, rsub(col(a), col(b)) + # becomes sub(col(b), col(a)) + op, first_operand, first_datatype, second_operand, second_datatype = ( + _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP[op], + second_operand, + second_datatype, + first_operand, + first_datatype, + ) - Returns: - SnowparkPandasColumn for translated pandas operation - """ - if op in _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP: - # Normalize right-sided binary operations to the equivalent left-sided - # operations with swapped operands. For example, rsub(col(a), col(b)) - # becomes sub(col(b), col(a)) - op, first_operand, first_datatype, second_operand, second_datatype = ( - _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP[op], - second_operand, - second_datatype, - first_operand, - first_datatype, + class_name = f"{snake_to_camel(op)}Op" + op_class = None + for subclass in BinaryOp.__subclasses__(): + if subclass.__name__ == class_name: + op_class = subclass + if op_class is None: + op_class = BinaryOp + return op_class( + op, first_operand, first_datatype, second_operand, second_datatype ) - binary_op_result_column = None - snowpark_pandas_type = None + @staticmethod + def create_with_fill_value( + op: str, + lhs: SnowparkColumn, + lhs_datatype: DataTypeGetter, + rhs: SnowparkColumn, + rhs_datatype: DataTypeGetter, + fill_value: Scalar, + ) -> "BinaryOp": + """ + Create a BinaryOp object to compute pandas binary operation for two SnowparkColumns with fill value for missing + values. + + Args: + op: pandas operation + first_operand: SnowparkColumn for lhs + first_datatype: Callable for Snowpark Datatype for lhs + second_operand: SnowparkColumn for rhs + second_datatype: Callable for Snowpark DateType for rhs + it is not needed. + fill_value: the value to fill missing values + + Helper method for performing binary operations. + 1. Fills NaN/None values in the lhs and rhs with the given fill_value. + 2. Computes the binary operation expression for lhs rhs. + + fill_value replaces NaN/None values when only either lhs or rhs is NaN/None, not both lhs and rhs. + For instance, with fill_value = 100, + 1. Given lhs = None and rhs = 10, lhs is replaced with fill_value. + result = lhs + rhs => None + 10 => 100 (replaced) + 10 = 110 + 2. Given lhs = 3 and rhs = None, rhs is replaced with fill_value. + result = lhs + rhs => 3 + None => 3 + 100 (replaced) = 103 + 3. Given lhs = None and rhs = None, neither lhs nor rhs is replaced since they both are None. + result = lhs + rhs => None + None => None. + + Args: + op: pandas operation to perform between lhs and rhs + lhs: the lhs SnowparkColumn + lhs_datatype: Callable for Snowpark Datatype for lhs + rhs: the rhs SnowparkColumn + rhs_datatype: Callable for Snowpark Datatype for rhs + fill_value: Fill existing missing (NaN) values, and any new element needed for + successful DataFrame alignment, with this value before computation. 
+ + Returns: + SnowparkPandasColumn for translated pandas operation + """ + lhs_cond, rhs_cond = lhs, rhs + if fill_value is not None: + fill_value_lit = pandas_lit(fill_value) + lhs_cond = iff(lhs.is_null() & ~rhs.is_null(), fill_value_lit, lhs) + rhs_cond = iff(rhs.is_null() & ~lhs.is_null(), fill_value_lit, rhs) + + return BinaryOp.create(op, lhs_cond, lhs_datatype, rhs_cond, rhs_datatype) + + @staticmethod + def create_with_rhs_scalar( + op: str, + first_operand: SnowparkColumn, + datatype: DataTypeGetter, + second_operand: Scalar, + ) -> "BinaryOp": + """ + Compute the binary operation between a Snowpark column and a scalar. + Args: + op: the name of binary operation + first_operand: The SnowparkColumn for lhs + datatype: Callable for Snowpark data type + second_operand: Scalar value + + Returns: + SnowparkPandasColumn for translated pandas operation + """ + + def second_datatype() -> DataType: + return infer_object_type(second_operand) + + return BinaryOp.create( + op, first_operand, datatype, pandas_lit(second_operand), second_datatype + ) - # some operators and the data types have to be handled specially to align with pandas - # However, it is difficult to fail early if the arithmetic operator is not compatible - # with the data type, so we just let the server raise exception (e.g. a string minus a string). - if ( - op == "add" - and isinstance(second_datatype(), TimedeltaType) - and isinstance(first_datatype(), TimestampType) - ): - binary_op_result_column = dateadd("ns", second_operand, first_operand) - elif ( - op == "add" - and isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), TimestampType) - ): - binary_op_result_column = dateadd("ns", first_operand, second_operand) - elif op in ( - "add", - "sub", - "eq", - "ne", - "gt", - "ge", - "lt", - "le", - "floordiv", - "truediv", - ) and ( - ( - isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), NullType) + @staticmethod + def create_with_lhs_scalar( + op: str, + first_operand: Scalar, + second_operand: SnowparkColumn, + datatype: DataTypeGetter, + ) -> "BinaryOp": + """ + Compute the binary operation between a scalar and a Snowpark column. + Args: + op: the name of binary operation + first_operand: Scalar value + second_operand: The SnowparkColumn for rhs + datatype: Callable for Snowpark data type + it is not needed. 
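# A small illustration (native pandas only) of the fill_value behavior
# documented above: the fill applies only when exactly one operand is missing.
import pandas as native_pd

lhs = native_pd.Series([None, 3.0, None])
rhs = native_pd.Series([10.0, None, None])
result = lhs.add(rhs, fill_value=100)
assert result.iloc[0] == 110.0          # lhs missing: 100 + 10
assert result.iloc[1] == 103.0          # rhs missing: 3 + 100
assert native_pd.isna(result.iloc[2])   # both missing: stays null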
+ + Returns: + SnowparkPandasColumn for translated pandas operation + """ + + def first_datatype() -> DataType: + return infer_object_type(first_operand) + + return BinaryOp.create( + op, pandas_lit(first_operand), first_datatype, second_operand, datatype ) - or ( - isinstance(second_datatype(), TimedeltaType) - and isinstance(first_datatype(), NullType) + + def _custom_compute(self) -> None: + """Implement custom compute method if needed.""" + pass + + def _get_result(self) -> SnowparkPandasColumn: + return SnowparkPandasColumn( + snowpark_column=self.result_column, + snowpark_pandas_type=self.result_snowpark_pandas_type, ) - ): - return SnowparkPandasColumn(pandas_lit(None), TimedeltaType()) - elif ( - op == "sub" - and isinstance(second_datatype(), TimedeltaType) - and isinstance(first_datatype(), TimestampType) - ): - binary_op_result_column = dateadd("ns", -1 * second_operand, first_operand) - elif ( - op == "sub" - and isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), TimestampType) - ): + + def _check_timedelta_with_none(self) -> None: + if self.op in ( + "add", + "sub", + "eq", + "ne", + "gt", + "ge", + "lt", + "le", + "floordiv", + "truediv", + ) and ( + ( + isinstance(self.first_datatype(), TimedeltaType) + and isinstance(self.second_datatype(), NullType) + ) + or ( + isinstance(self.second_datatype(), TimedeltaType) + and isinstance(self.first_datatype(), NullType) + ) + ): + self.result_column = pandas_lit(None) + self.result_snowpark_pandas_type = TimedeltaType() + + def _check_error(self) -> None: # Timedelta - Timestamp doesn't make sense. Raise the same error # message as pandas. - raise TypeError("bad operand type for unary -: 'DatetimeArray'") - elif op == "mod" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - binary_op_result_column = compute_modulo_between_snowpark_columns( - first_operand, first_datatype(), second_operand, second_datatype() - ) - snowpark_pandas_type = TimedeltaType() - elif op == "pow" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - raise TypeError("unsupported operand type for **: Timedelta") - elif op == "__or__" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - raise TypeError("unsupported operand type for |: Timedelta") - elif op == "__and__" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - raise TypeError("unsupported operand type for &: Timedelta") - elif ( - op in ("add", "sub") - and isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), TimedeltaType) - ): - snowpark_pandas_type = TimedeltaType() - elif op == "mul" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - raise np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined] - np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]")) - ) - elif op in ( - "eq", - "ne", - "gt", - "ge", - "lt", - "le", - ) and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - # These operations, when done between timedeltas, work without any - # extra handling in `snowpark_pandas_type` or `binary_op_result_column`. 
- pass - elif op == "mul" and ( - _op_is_between_timedelta_and_numeric(first_datatype, second_datatype) - ): - binary_op_result_column = cast( - floor(first_operand * second_operand), LongType() - ) - snowpark_pandas_type = TimedeltaType() - # For `eq` and `ne`, note that Snowflake will consider 1 equal to - # Timedelta(1) because those two have the same representation in Snowflake, - # so we have to compare types in the client. - elif op == "eq" and ( - _op_is_between_timedelta_and_numeric(first_datatype, second_datatype) - ): - binary_op_result_column = pandas_lit(False) - elif op == "ne" and _op_is_between_timedelta_and_numeric( - first_datatype, second_datatype - ): - binary_op_result_column = pandas_lit(True) - elif ( - op in ("truediv", "floordiv") - and isinstance(first_datatype(), TimedeltaType) - and _is_numeric_non_timedelta_type(second_datatype()) - ): - binary_op_result_column = cast( - floor(first_operand / second_operand), LongType() - ) - snowpark_pandas_type = TimedeltaType() - elif ( - op == "mod" - and isinstance(first_datatype(), TimedeltaType) - and _is_numeric_non_timedelta_type(second_datatype()) - ): - binary_op_result_column = ceil( - compute_modulo_between_snowpark_columns( - first_operand, first_datatype(), second_operand, second_datatype() + if ( + self.op == "sub" + and isinstance(self.first_datatype(), TimedeltaType) + and isinstance(self.second_datatype(), TimestampType) + ): + raise TypeError("bad operand type for unary -: 'DatetimeArray'") + + # Raise error for two timedelta or timedelta and null + two_timedeltas_or_timedelta_and_null_error = { + "pow": TypeError("unsupported operand type for **: Timedelta"), + "__or__": TypeError("unsupported operand type for |: Timedelta"), + "__and__": TypeError("unsupported operand type for &: Timedelta"), + "mul": np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined] + np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]")) + ), + } + if ( + self.op in two_timedeltas_or_timedelta_and_null_error + and _op_is_between_two_timedeltas_or_timedelta_and_null( + self.first_datatype(), self.second_datatype() ) - ) - snowpark_pandas_type = TimedeltaType() - elif op in ("add", "sub") and ( - ( - isinstance(first_datatype(), TimedeltaType) - and _is_numeric_non_timedelta_type(second_datatype()) - ) - or ( - _is_numeric_non_timedelta_type(first_datatype()) - and isinstance(second_datatype(), TimedeltaType) - ) - ): - raise TypeError( - "Snowpark pandas does not support addition or subtraction between timedelta values and numeric values." - ) - elif op in ("truediv", "floordiv", "mod") and ( - _is_numeric_non_timedelta_type(first_datatype()) - and isinstance(second_datatype(), TimedeltaType) - ): - raise TypeError( - "Snowpark pandas does not support dividing numeric values by timedelta values with div (/), mod (%), or floordiv (//)." 
- ) - elif op in ( - "add", - "sub", - "truediv", - "floordiv", - "mod", - "gt", - "ge", - "lt", - "le", - "ne", - "eq", - ) and ( - ( - isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), StringType) - ) - or ( - isinstance(second_datatype(), TimedeltaType) - and isinstance(first_datatype(), StringType) - ) - ): + ): + raise two_timedeltas_or_timedelta_and_null_error[self.op] + + if self.op in ("add", "sub") and ( + ( + isinstance(self.first_datatype(), TimedeltaType) + and _is_numeric_non_timedelta_type(self.second_datatype()) + ) + or ( + _is_numeric_non_timedelta_type(self.first_datatype()) + and isinstance(self.second_datatype(), TimedeltaType) + ) + ): + raise TypeError( + "Snowpark pandas does not support addition or subtraction between timedelta values and numeric values." + ) + + if self.op in ("truediv", "floordiv", "mod") and ( + _is_numeric_non_timedelta_type(self.first_datatype()) + and isinstance(self.second_datatype(), TimedeltaType) + ): + raise TypeError( + "Snowpark pandas does not support dividing numeric values by timedelta values with div (/), mod (%), " + "or floordiv (//)." + ) + # TODO(SNOW-1646604): Support these cases. - ErrorMessage.not_implemented( - f"Snowpark pandas does not yet support the operation {op} between timedelta and string" - ) - elif op in ("gt", "ge", "lt", "le", "pow", "__or__", "__and__") and ( - _op_is_between_timedelta_and_numeric(first_datatype, second_datatype) - ): - raise TypeError( - f"Snowpark pandas does not support binary operation {op} between timedelta and a non-timedelta type." - ) - elif op == "floordiv": - binary_op_result_column = floor(first_operand / second_operand) - elif op == "mod": - binary_op_result_column = compute_modulo_between_snowpark_columns( - first_operand, first_datatype(), second_operand, second_datatype() - ) - elif op == "pow": - binary_op_result_column = compute_power_between_snowpark_columns( - first_operand, second_operand - ) - elif op == "__or__": - binary_op_result_column = first_operand | second_operand - elif op == "__and__": - binary_op_result_column = first_operand & second_operand - elif ( - op == "add" - and isinstance(second_datatype(), StringType) - and isinstance(first_datatype(), StringType) - ): - # string/string case (only for add) - binary_op_result_column = concat(first_operand, second_operand) - elif op == "mul" and ( - ( - isinstance(second_datatype(), _IntegralType) - and isinstance(first_datatype(), StringType) - ) - or ( - isinstance(second_datatype(), StringType) - and isinstance(first_datatype(), _IntegralType) + if self.op in ( + "add", + "sub", + "truediv", + "floordiv", + "mod", + "gt", + "ge", + "lt", + "le", + "ne", + "eq", + ) and ( + ( + isinstance(self.first_datatype(), TimedeltaType) + and isinstance(self.second_datatype(), StringType) + ) + or ( + isinstance(self.second_datatype(), TimedeltaType) + and isinstance(self.first_datatype(), StringType) + ) + ): + ErrorMessage.not_implemented( + f"Snowpark pandas does not yet support the operation {self.op} between timedelta and string" + ) + + if self.op in ("gt", "ge", "lt", "le", "pow", "__or__", "__and__") and ( + _op_is_between_timedelta_and_numeric( + self.first_datatype, self.second_datatype + ) + ): + raise TypeError( + f"Snowpark pandas does not support binary operation {self.op} between timedelta and a non-timedelta " + f"type." 
+ ) + + def compute(self) -> SnowparkPandasColumn: + self._check_error() + + self._check_timedelta_with_none() + + if self.result_column is not None: + return self._get_result() + + # Generally, some operators and the data types have to be handled specially to align with pandas + # However, it is difficult to fail early if the arithmetic operator is not compatible + # with the data type, so we just let the server raise exception (e.g. a string minus a string). + + self._custom_compute() + if self.result_column is None: + # If there is no special binary_op_result_column result, it means the operator and + # the data type of the column don't need special handling. Then we get the overloaded + # operator from Snowpark Column class, e.g., __add__ to perform binary operations. + self.result_column = getattr(self.first_operand, f"__{self.op}__")( + self.second_operand + ) + + return self._get_result() + + +class AddOp(BinaryOp): + def _custom_compute(self) -> None: + if isinstance(self.second_datatype(), TimedeltaType) and isinstance( + self.first_datatype(), TimestampType + ): + self.result_column = dateadd("ns", self.second_operand, self.first_operand) + elif isinstance(self.first_datatype(), TimedeltaType) and isinstance( + self.second_datatype(), TimestampType + ): + self.result_column = dateadd("ns", self.first_operand, self.second_operand) + elif isinstance(self.first_datatype(), TimedeltaType) and isinstance( + self.second_datatype(), TimedeltaType + ): + self.result_snowpark_pandas_type = TimedeltaType() + elif isinstance(self.second_datatype(), StringType) and isinstance( + self.first_datatype(), StringType + ): + # string/string case (only for add) + self.result_column = concat(self.first_operand, self.second_operand) + + +class SubOp(BinaryOp): + def _custom_compute(self) -> None: + if isinstance(self.second_datatype(), TimedeltaType) and isinstance( + self.first_datatype(), TimestampType + ): + self.result_column = dateadd( + "ns", -1 * self.second_operand, self.first_operand + ) + elif isinstance(self.first_datatype(), TimedeltaType) and isinstance( + self.second_datatype(), TimedeltaType + ): + self.result_snowpark_pandas_type = TimedeltaType() + elif isinstance(self.first_datatype(), TimestampType) and isinstance( + self.second_datatype(), NullType + ): + # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, + # but it's valid in pandas and returns NULL. + self.result_column = pandas_lit(None) + elif isinstance(self.first_datatype(), NullType) and isinstance( + self.second_datatype(), TimestampType + ): + # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, + # but it's valid in pandas and returns NULL. + self.result_column = pandas_lit(None) + elif isinstance(self.first_datatype(), TimestampType) and isinstance( + self.second_datatype(), TimestampType + ): + ( + self.result_column, + self.result_snowpark_pandas_type, + ) = _compute_subtraction_between_snowpark_timestamp_columns( + first_operand=self.first_operand, + first_datatype=self.first_datatype(), + second_operand=self.second_operand, + second_datatype=self.second_datatype(), + ) + + +class ModOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = compute_modulo_between_snowpark_columns( + self.first_operand, + self.first_datatype(), + self.second_operand, + self.second_datatype(), ) - ): - # string/integer case (only for mul/rmul). 
- # swap first_operand with second_operand because - # REPEAT(, ) expects to be string - if isinstance(first_datatype(), _IntegralType): - first_operand, second_operand = second_operand, first_operand - - binary_op_result_column = iff( - second_operand > pandas_lit(0), - repeat(first_operand, second_operand), - # Snowflake's repeat doesn't support negative number, - # but pandas will return an empty string - pandas_lit(""), + if _op_is_between_two_timedeltas_or_timedelta_and_null( + self.first_datatype(), self.second_datatype() + ): + self.result_snowpark_pandas_type = TimedeltaType() + elif isinstance( + self.first_datatype(), TimedeltaType + ) and _is_numeric_non_timedelta_type(self.second_datatype()): + self.result_column = ceil(self.result_column) + self.result_snowpark_pandas_type = TimedeltaType() + + +class MulOp(BinaryOp): + def _custom_compute(self) -> None: + if _op_is_between_timedelta_and_numeric( + self.first_datatype, self.second_datatype + ): + self.result_column = cast( + floor(self.first_operand * self.second_operand), LongType() + ) + self.result_snowpark_pandas_type = TimedeltaType() + elif ( + isinstance(self.second_datatype(), _IntegralType) + and isinstance(self.first_datatype(), StringType) + ) or ( + isinstance(self.second_datatype(), StringType) + and isinstance(self.first_datatype(), _IntegralType) + ): + # string/integer case (only for mul/rmul). + # swap first_operand with second_operand because + # REPEAT(, ) expects to be string + if isinstance(self.first_datatype(), _IntegralType): + self.first_operand, self.second_operand = ( + self.second_operand, + self.first_operand, + ) + + self.result_column = iff( + self.second_operand > pandas_lit(0), + repeat(self.first_operand, self.second_operand), + # Snowflake's repeat doesn't support negative number, + # but pandas will return an empty string + pandas_lit(""), + ) + + +class EqOp(BinaryOp): + def _custom_compute(self) -> None: + # For `eq` and `ne`, note that Snowflake will consider 1 equal to + # Timedelta(1) because those two have the same representation in Snowflake, + # so we have to compare types in the client. + if _op_is_between_timedelta_and_numeric( + self.first_datatype, self.second_datatype + ): + self.result_column = pandas_lit(False) + + +class NeOp(BinaryOp): + def _custom_compute(self) -> None: + # For `eq` and `ne`, note that Snowflake will consider 1 equal to + # Timedelta(1) because those two have the same representation in Snowflake, + # so we have to compare types in the client. 
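+        # For example, pandas treats pd.Timedelta(1) != 1 as True, so the
+        # result is a literal True column rather than a value comparison.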
+ if _op_is_between_timedelta_and_numeric( + self.first_datatype, self.second_datatype + ): + self.result_column = pandas_lit(True) + + +class FloordivOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = floor(self.first_operand / self.second_operand) + if isinstance( + self.first_datatype(), TimedeltaType + ) and _is_numeric_non_timedelta_type(self.second_datatype()): + self.result_column = cast(self.result_column, LongType()) + self.result_snowpark_pandas_type = TimedeltaType() + + +class TruedivOp(BinaryOp): + def _custom_compute(self) -> None: + if isinstance( + self.first_datatype(), TimedeltaType + ) and _is_numeric_non_timedelta_type(self.second_datatype()): + self.result_column = cast( + floor(self.first_operand / self.second_operand), LongType() + ) + self.result_snowpark_pandas_type = TimedeltaType() + + +class PowOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = compute_power_between_snowpark_columns( + self.first_operand, self.second_operand ) - elif op == "equal_null": + + +class OrOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = self.first_operand | self.second_operand + + +class AndOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = self.first_operand & self.second_operand + + +class EqualNullOp(BinaryOp): + def _custom_compute(self) -> None: # TODO(SNOW-1641716): In Snowpark pandas, generally use this equal_null # with type checking intead of snowflake.snowpark.functions.equal_null. - if not are_equal_types(first_datatype(), second_datatype()): - binary_op_result_column = pandas_lit(False) + if not are_equal_types(self.first_datatype(), self.second_datatype()): + self.result_column = pandas_lit(False) else: - binary_op_result_column = first_operand.equal_null(second_operand) - elif ( - op == "sub" - and isinstance(first_datatype(), TimestampType) - and isinstance(second_datatype(), NullType) - ): - # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, - # but it's valid in pandas and returns NULL. - binary_op_result_column = pandas_lit(None) - elif ( - op == "sub" - and isinstance(first_datatype(), NullType) - and isinstance(second_datatype(), TimestampType) - ): - # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, - # but it's valid in pandas and returns NULL. - binary_op_result_column = pandas_lit(None) - elif ( - op == "sub" - and isinstance(first_datatype(), TimestampType) - and isinstance(second_datatype(), TimestampType) - ): - return _compute_subtraction_between_snowpark_timestamp_columns( - first_operand=first_operand, - first_datatype=first_datatype(), - second_operand=second_operand, - second_datatype=second_datatype(), - ) - # If there is no special binary_op_result_column result, it means the operator and - # the data type of the column don't need special handling. Then we get the overloaded - # operator from Snowpark Column class, e.g., __add__ to perform binary operations. 
- if binary_op_result_column is None: - binary_op_result_column = getattr(first_operand, f"__{op}__")(second_operand) - - return SnowparkPandasColumn( - snowpark_column=binary_op_result_column, - snowpark_pandas_type=snowpark_pandas_type, - ) + self.result_column = self.first_operand.equal_null(self.second_operand) def are_equal_types(type1: DataType, type2: DataType) -> bool: @@ -644,104 +844,6 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool: return type1 == type2 -def compute_binary_op_between_snowpark_column_and_scalar( - op: str, - first_operand: SnowparkColumn, - datatype: DataTypeGetter, - second_operand: Scalar, -) -> SnowparkPandasColumn: - """ - Compute the binary operation between a Snowpark column and a scalar. - Args: - op: the name of binary operation - first_operand: The SnowparkColumn for lhs - datatype: Callable for Snowpark data type - second_operand: Scalar value - - Returns: - SnowparkPandasColumn for translated pandas operation - """ - - def second_datatype() -> DataType: - return infer_object_type(second_operand) - - return compute_binary_op_between_snowpark_columns( - op, first_operand, datatype, pandas_lit(second_operand), second_datatype - ) - - -def compute_binary_op_between_scalar_and_snowpark_column( - op: str, - first_operand: Scalar, - second_operand: SnowparkColumn, - datatype: DataTypeGetter, -) -> SnowparkPandasColumn: - """ - Compute the binary operation between a scalar and a Snowpark column. - Args: - op: the name of binary operation - first_operand: Scalar value - second_operand: The SnowparkColumn for rhs - datatype: Callable for Snowpark data type - it is not needed. - - Returns: - SnowparkPandasColumn for translated pandas operation - """ - - def first_datatype() -> DataType: - return infer_object_type(first_operand) - - return compute_binary_op_between_snowpark_columns( - op, pandas_lit(first_operand), first_datatype, second_operand, datatype - ) - - -def compute_binary_op_with_fill_value( - op: str, - lhs: SnowparkColumn, - lhs_datatype: DataTypeGetter, - rhs: SnowparkColumn, - rhs_datatype: DataTypeGetter, - fill_value: Scalar, -) -> SnowparkPandasColumn: - """ - Helper method for performing binary operations. - 1. Fills NaN/None values in the lhs and rhs with the given fill_value. - 2. Computes the binary operation expression for lhs rhs. - - fill_value replaces NaN/None values when only either lhs or rhs is NaN/None, not both lhs and rhs. - For instance, with fill_value = 100, - 1. Given lhs = None and rhs = 10, lhs is replaced with fill_value. - result = lhs + rhs => None + 10 => 100 (replaced) + 10 = 110 - 2. Given lhs = 3 and rhs = None, rhs is replaced with fill_value. - result = lhs + rhs => 3 + None => 3 + 100 (replaced) = 103 - 3. Given lhs = None and rhs = None, neither lhs nor rhs is replaced since they both are None. - result = lhs + rhs => None + None => None. - - Args: - op: pandas operation to perform between lhs and rhs - lhs: the lhs SnowparkColumn - lhs_datatype: Callable for Snowpark Datatype for lhs - rhs: the rhs SnowparkColumn - rhs_datatype: Callable for Snowpark Datatype for rhs - fill_value: Fill existing missing (NaN) values, and any new element needed for - successful DataFrame alignment, with this value before computation. 
- - Returns: - SnowparkPandasColumn for translated pandas operation - """ - lhs_cond, rhs_cond = lhs, rhs - if fill_value is not None: - fill_value_lit = pandas_lit(fill_value) - lhs_cond = iff(lhs.is_null() & ~rhs.is_null(), fill_value_lit, lhs) - rhs_cond = iff(rhs.is_null() & ~lhs.is_null(), fill_value_lit, rhs) - - return compute_binary_op_between_snowpark_columns( - op, lhs_cond, lhs_datatype, rhs_cond, rhs_datatype - ) - - def merge_label_and_identifier_pairs( sorted_column_labels: list[str], q_frame_sorted: list[tuple[str, str]], diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py index 882dc79d2a8..4eaf98d9b29 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py @@ -189,8 +189,6 @@ def compute_bin_indices( values_frame, cuts_frame, how="asof", - left_on=[], - right_on=[], left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0], right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0], match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO diff --git a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py index c2c224e404c..6207bd2399a 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py @@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame( key, count_frame, "cross", - left_on=[], - right_on=[], inherit_join_index=InheritJoinIndex.FROM_LEFT, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 457bd388f2b..d07211dbcf5 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple): result_column_mapper: JoinOrAlignResultColumnMapper +def assert_snowpark_pandas_types_match( + left: InternalFrame, + right: InternalFrame, + left_join_identifiers: list[str], + right_join_identifiers: list[str], +) -> None: + """ + If Snowpark pandas types do not match for the given identifiers, then a ValueError will be raised. + + Args: + left: An internal frame to use on left side of join. + right: An internal frame to use on right side of join. + left_join_identifiers: List of snowflake identifiers to check types from 'left' frame. + right_join_identifiers: List of snowflake identifiers to check types from 'right' frame. + left_identifiers and right_identifiers must be lists of equal length. 
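# A small illustration (native pandas, recent versions) of the error this check
# mirrors when merge key types are incompatible.
import pandas as native_pd

left_frame = native_pd.DataFrame({"k": [1, 2]})
right_frame = native_pd.DataFrame({"k": ["1", "2"]})
try:
    native_pd.merge(left_frame, right_frame, on="k")
except ValueError as error:
    assert "you should use pd.concat" in str(error)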
+ + Returns: None + + Raises: ValueError + """ + left_types = [ + left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in left_join_identifiers + ] + right_types = [ + right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in right_join_identifiers + ] + for i, (lt, rt) in enumerate(zip(left_types, right_types)): + if lt != rt: + left_on_id = left_join_identifiers[i] + idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) + key = left.data_column_pandas_labels[idx] + lt = lt if lt is not None else left.get_snowflake_type(left_on_id) + rt = ( + rt + if rt is not None + else right.get_snowflake_type(right_join_identifiers[i]) + ) + raise ValueError( + f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. " + f"If you wish to proceed you should use pd.concat" + ) + + def join( left: InternalFrame, right: InternalFrame, how: JoinTypeLit, - left_on: list[str], - right_on: list[str], + left_on: Optional[list[str]] = None, + right_on: Optional[list[str]] = None, left_match_col: Optional[str] = None, right_match_col: Optional[str] = None, match_comparator: Optional[MatchComparator] = None, @@ -161,40 +206,48 @@ def join( include mapping for index + data columns, ordering columns and row position column if exists. """ - assert len(left_on) == len( - right_on - ), "left_on and right_on must be of same length or both be None" - if join_key_coalesce_config is not None: - assert len(join_key_coalesce_config) == len( - left_on - ), "join_key_coalesce_config must be of same length as left_on and right_on" assert how in get_args( JoinTypeLit ), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}" - def assert_snowpark_pandas_types_match() -> None: - """If Snowpark pandas types do not match, then a ValueError will be raised.""" - left_types = [ - left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in left_on - ] - right_types = [ - right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in right_on - ] - for i, (lt, rt) in enumerate(zip(left_types, right_types)): - if lt != rt: - left_on_id = left_on[i] - idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) - key = left.data_column_pandas_labels[idx] - lt = lt if lt is not None else left.get_snowflake_type(left_on_id) - rt = rt if rt is not None else right.get_snowflake_type(right_on[i]) - raise ValueError( - f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. 
" - f"If you wish to proceed you should use pd.concat" - ) + left_on = left_on or [] + right_on = right_on or [] + assert len(left_on) == len( + right_on + ), "left_on and right_on must be of same length or both be None" - assert_snowpark_pandas_types_match() + if how == "asof": + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + left_join_key = [left_match_col] + right_join_key = [right_match_col] + left_join_key.extend(left_on) + right_join_key.extend(right_on) + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key" + else: + left_join_key = left_on + right_join_key = right_on + assert ( + left_match_col is None + and right_match_col is None + and match_comparator is None + ), f"match condition should not be provided for {how} join" + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "join_key_coalesce_config must be of same length as left_on and right_on" + + assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key) # Re-project the active columns to make sure all active columns of the internal frame participate # in the join operation, and unnecessary columns are dropped from the projected columns. @@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None: match_comparator=match_comparator, how=how, ) - return _create_internal_frame_with_join_or_align_result( joined_ordered_dataframe, left, right, how, - left_on, - right_on, + left_join_key, + right_join_key, sort, join_key_coalesce_config, inherit_join_index, @@ -1075,7 +1127,7 @@ def join_on_index_columns( Returns: An InternalFrame for the joined result. - A JoinOrAlignResultColumnMapper that provides quited identifiers mapping from the + A JoinOrAlignResultColumnMapper that provides quoted identifiers mapping from the original left and right dataframe to the joined dataframe, it is guaranteed to include mapping for index + data columns, ordering columns and row position column if exists. @@ -1263,7 +1315,7 @@ def align_on_index( * outer: use union of index from both frames, sort index lexicographically. Returns: An InternalFrame for the aligned result. - A JoinOrAlignResultColumnMapper that provides quited identifiers mapping from the + A JoinOrAlignResultColumnMapper that provides quoted identifiers mapping from the original left and right dataframe to the aligned dataframe, it is guaranteed to include mapping for index + data columns, ordering columns and row position column if exists. 
@@ -1402,6 +1454,9 @@ def _sort_on_join_keys(self) -> None: ) elif self._how == "right": ordering_column_identifiers = mapped_right_on + elif self._how == "asof": + # Order only by the left match_condition column + ordering_column_identifiers = [mapped_left_on[0]] else: # left join, inner join, left align, coalesce align ordering_column_identifiers = mapped_left_on diff --git a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py index f7ae87c2a5d..91537d98e30 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py @@ -1197,22 +1197,29 @@ def join( # get the new mapped right on identifier right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols] - # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' - on = None - for left_col, right_col in zip(left_on_cols, right_on_cols): - eq = Column(left_col).equal_null(Column(right_col)) - on = eq if on is None else on & eq - if how == "asof": - assert left_match_col, "left_match_col was not provided to ASOF Join" + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" left_match_col = Column(left_match_col) # Get the new mapped right match condition identifier - assert right_match_col, "right_match_col was not provided to ASOF Join" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" right_match_col = Column(right_identifiers_rename_map[right_match_col]) # ASOF Join requires the use of match_condition - assert match_comparator, "match_comparator was not provided to ASOF Join" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + + on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).__eq__(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right=right_snowpark_dataframe_ref.snowpark_dataframe, + on=on, how=how, match_condition=getattr(left_match_col, match_comparator.value)( right_match_col @@ -1224,6 +1231,12 @@ def join( right_snowpark_dataframe_ref.snowpark_dataframe, how=how ) else: + # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' 
+ on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).equal_null(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right_snowpark_dataframe_ref.snowpark_dataframe, on, how ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py index 3bf1062107e..e7a96b49ef1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py @@ -520,12 +520,15 @@ def single_pivot_helper( data_column_snowflake_quoted_identifiers: new data column snowflake quoted identifiers this pivot result data_column_pandas_labels: new data column pandas labels for this pivot result """ - snowpark_aggr_func = get_snowflake_agg_func(pandas_aggr_func_name, {}) - if not is_supported_snowflake_pivot_agg_func(snowpark_aggr_func): + snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0) + if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func( + snowflake_agg_func.snowpark_aggregation + ): # TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations raise ErrorMessage.not_implemented( f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments." ) + snowpark_aggr_func = snowflake_agg_func.snowpark_aggregation pandas_aggr_label, aggr_snowflake_quoted_identifier = value_label_to_identifier_pair @@ -1231,17 +1234,19 @@ def get_margin_aggregation( Returns: Snowpark column expression for the aggregation function result. """ - resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}) + resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}, axis=0) # This would have been resolved during the original pivot at an early stage. assert resolved_aggfunc is not None, "resolved_aggfunc is None" - aggfunc_expr = resolved_aggfunc(snowflake_quoted_identifier) + aggregation_expression = resolved_aggfunc.snowpark_aggregation( + snowflake_quoted_identifier + ) - if resolved_aggfunc == sum_: - aggfunc_expr = coalesce(aggfunc_expr, pandas_lit(0)) + if resolved_aggfunc.snowpark_aggregation == sum_: + aggregation_expression = coalesce(aggregation_expression, pandas_lit(0)) - return aggfunc_expr + return aggregation_expression def expand_pivot_result_with_pivot_table_margins_no_groupby_columns( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py index de83e0429bf..ba8ceedec5e 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py @@ -649,8 +649,6 @@ def perform_asof_join_on_frame( left=preserving_frame, right=referenced_frame, how="asof", - left_on=[], - right_on=[], left_match_col=left_timecol_snowflake_quoted_identifier, right_match_col=right_timecol_snowflake_quoted_identifier, match_comparator=( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py index d38584c14de..e19a6de37ba 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py @@ -567,9 +567,7 @@ def __new__( attrs (Dict[str, Any]): The attributes of the class. 
         Returns:
-            Union[snowflake.snowpark.modin.pandas.series.Series,
-                  snowflake.snowpark.modin.pandas.dataframe.DataFrame,
-                  snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy,
+            Union[snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy,
                   snowflake.snowpark.modin.pandas.resample.Resampler,
                   snowflake.snowpark.modin.pandas.window.Window,
                   snowflake.snowpark.modin.pandas.window.Rolling]:
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py
index 0242177d1f0..3b714087535 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py
@@ -22,9 +22,17 @@
     cast,
     convert_timezone,
     date_part,
+    dayofmonth,
+    hour,
     iff,
+    minute,
+    month,
+    second,
+    timestamp_tz_from_parts,
     to_decimal,
+    to_timestamp_ntz,
     trunc,
+    year,
 )
 from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit
 from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
@@ -467,3 +475,60 @@ def convert_dateoffset_to_interval(
         )
         interval_kwargs[new_param] = offset
     return Interval(**interval_kwargs)
+
+
+def tz_localize_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column:
+    """
+    Localize a tz-naive datetime column to tz-aware.
+
+    Args:
+        column: the Snowpark datetime column
+        tz: str, pytz.timezone, or None. The time zone to localize the column's timestamps to. A tz of None will convert to UTC and remove the timezone information.
+
+    Returns:
+        The column after tz localization
+    """
+    if tz is None:
+        # If this column is already a TIMESTAMP_NTZ, this cast does nothing.
+        # If the column is a TIMESTAMP_TZ, the cast drops the timezone and converts
+        # to TIMESTAMP_NTZ.
+ return to_timestamp_ntz(column) + else: + if isinstance(tz, dt.tzinfo): + tz_name = tz.tzname(None) + else: + tz_name = tz + return timestamp_tz_from_parts( + year(column), + month(column), + dayofmonth(column), + hour(column), + minute(column), + second(column), + date_part("nanosecond", column), + pandas_lit(tz_name), + ) + + +def tz_convert_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column: + """ + Converts a datetime column to the specified timezone + + Args: + column: the Snowpark datetime column + tz: the target timezone + + Returns: + The column after conversion to the specified timezone + """ + if tz is None: + return to_timestamp_ntz(convert_timezone(pandas_lit("UTC"), column)) + else: + if isinstance(tz, dt.tzinfo): + tz_name = tz.tzname(None) + else: + tz_name = tz + return convert_timezone(pandas_lit(tz_name), column) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py index 5656bbfb14a..34a3376fcc1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py @@ -41,6 +41,7 @@ mean, min as min_, sum as sum_, + to_char, to_timestamp_ntz, to_timestamp_tz, typeof, @@ -75,6 +76,8 @@ StringType, StructField, StructType, + TimestampTimeZone, + TimestampType, VariantType, _FractionalType, ) @@ -1273,7 +1276,7 @@ def check_snowpark_pandas_object_in_arg(arg: Any) -> bool: if check_snowpark_pandas_object_in_arg(v): return True else: - from snowflake.snowpark.modin.pandas import DataFrame, Series + from modin.pandas import DataFrame, Series return isinstance(arg, (DataFrame, Series)) @@ -1289,14 +1292,23 @@ def snowpark_to_pandas_helper( ) -> Union[native_pd.Index, native_pd.DataFrame]: """ The helper function retrieves a pandas dataframe from an OrderedDataFrame. Performs necessary type - conversions for variant types on the client. This function issues 2 queries, one metadata query - to retrieve the schema and one query to retrieve the data values. + conversions including + 1. For VARIANT types, OrderedDataFrame.to_pandas may convert datetime like types to string. So we add one `typeof` + column for each variant column and use that metadata to convert datetime like types back to their original types. + 2. For TIMESTAMP_TZ type, OrderedDataFrame.to_pandas will convert them into the local session timezone and lose the + original timezone. So we cast TIMESTAMP_TZ columns to string first and then convert them back after to_pandas to + preserve the original timezone. Note that the actual timezone will be lost in Snowflake backend but only the offset + preserved. + 3. For Timedelta columns, since currently we represent the values using integers, here we need to explicitly cast + them back to Timedelta. Args: frame: The internal frame to convert to pandas Dataframe (or Index if index_only is true) index_only: if true, only turn the index columns into a pandas Index - statement_params: Dictionary of statement level parameters to be passed to conversion function of ordered dataframe abstraction. - kwargs: Additional keyword-only args to pass to internal `to_pandas` conversion for orderded dataframe abstraction. + statement_params: Dictionary of statement level parameters to be passed to conversion function of ordered + dataframe abstraction. + kwargs: Additional keyword-only args to pass to internal `to_pandas` conversion for ordered dataframe + abstraction. 
Returns: pandas dataframe @@ -1365,7 +1377,7 @@ def snowpark_to_pandas_helper( ) variant_type_identifiers = list(map(lambda t: t[0], variant_type_columns_info)) - # Step 3: Create for each variant type column a separate type column (append at end), and retrieve data values + # Step 3.1: Create for each variant type column a separate type column (append at end), and retrieve data values # (and types for variant type columns). variant_type_typeof_identifiers = ( ordered_dataframe.generate_snowflake_quoted_identifiers( @@ -1384,10 +1396,36 @@ def snowpark_to_pandas_helper( [typeof(col(id)) for id in variant_type_identifiers], ) + # Step 3.2: cast timestamp_tz to string to preserve their original timezone offsets + timestamp_tz_identifiers = [ + info[0] + for info in columns_info + if info[1] == TimestampType(TimestampTimeZone.TZ) + ] + timestamp_tz_str_identifiers = ( + ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=[ + f"{unquote_name_if_quoted(id)}_str" for id in timestamp_tz_identifiers + ], + excluded=column_identifiers, + ) + ) + if len(timestamp_tz_identifiers): + ordered_dataframe = append_columns( + ordered_dataframe, + timestamp_tz_str_identifiers, + [ + to_char(col(id), format="YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM") + for id in timestamp_tz_identifiers + ], + ) + # ensure that snowpark_df has unique identifiers, so the native pandas DataFrame object created here # also does have unique column names which is a prerequisite for the post-processing logic following. assert is_duplicate_free( - column_identifiers + variant_type_typeof_identifiers + column_identifiers + + variant_type_typeof_identifiers + + timestamp_tz_str_identifiers ), "Snowpark DataFrame to convert must have unique column identifiers" pandas_df = ordered_dataframe.to_pandas(statement_params=statement_params, **kwargs) @@ -1400,7 +1438,9 @@ def snowpark_to_pandas_helper( # Step 3a: post-process variant type columns, if any exist. id_to_label_mapping = dict( zip( - column_identifiers + variant_type_typeof_identifiers, + column_identifiers + + variant_type_typeof_identifiers + + timestamp_tz_str_identifiers, pandas_df.columns, ) ) @@ -1439,6 +1479,25 @@ def convert_variant_type_to_pandas(row: native_pd.Series) -> Any: id_to_label_mapping[quoted_name] ].apply(lambda value: None if value is None else json.loads(value)) + # Convert timestamp_tz in string back to datetime64tz. + if any( + dtype == TimestampType(TimestampTimeZone.TZ) for (_, dtype) in columns_info + ): + id_to_label_mapping = dict( + zip( + column_identifiers + + variant_type_typeof_identifiers + + timestamp_tz_str_identifiers, + pandas_df.columns, + ) + ) + for ts_id, ts_str_id in zip( + timestamp_tz_identifiers, timestamp_tz_str_identifiers + ): + pandas_df[id_to_label_mapping[ts_id]] = native_pd.to_datetime( + pandas_df[id_to_label_mapping[ts_str_id]] + ) + # Step 5. Return the original amount of columns by stripping any typeof(...) columns appended if # schema contained VariantType. 
downcast_pandas_df = pandas_df[pandas_df.columns[: len(columns_info)]] @@ -1460,9 +1519,15 @@ def convert_str_to_timedelta(x: str) -> pd.Timedelta: downcast_pandas_df.columns, cached_snowpark_pandas_types ): if snowpark_pandas_type is not None and snowpark_pandas_type == timedelta_t: - downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply( - convert_str_to_timedelta - ) + # By default, pandas warns, "A value is trying to be set on a + # copy of a slice from a DataFrame" here because we are + # assigning a column to downcast_pandas_df, which is a copy of + # a slice of pandas_df. We don't care what happens to pandas_df, + # so the warning isn't useful to us. + with native_pd.option_context("mode.chained_assignment", None): + downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply( + convert_str_to_timedelta + ) # Step 7. postprocessing for return types if index_only: @@ -1493,7 +1558,11 @@ def convert_str_to_timedelta(x: str) -> pd.Timedelta: # multiple timezones. So here we cast the index to the index_type when ret = pd.Index(...) above cannot # figure out a non-object dtype. Note that the index_type is a logical type may not be 100% accurate. if is_object_dtype(ret.dtype) and not is_object_dtype(index_type): - ret = ret.astype(index_type) + # TODO: SNOW-1657460 fix index_type for timestamp_tz + try: + ret = ret.astype(index_type) + except ValueError: # e.g., Tz-aware datetime.datetime cannot be converted to datetime64 + pass return ret # to_pandas() does not preserve the index information and will just return a diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 400e98562f9..e971b15b6d6 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -149,8 +149,6 @@ ) from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( AGG_NAME_COL_LABEL, - GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE, - GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES, AggFuncInfo, AggFuncWithLabel, AggregateColumnOpParameters, @@ -161,7 +159,6 @@ convert_agg_func_arg_to_col_agg_func_map, drop_non_numeric_data_columns, generate_column_agg_info, - generate_rowwise_aggregation_function, get_agg_func_to_col_map, get_pandas_aggr_func_name, get_snowflake_agg_func, @@ -172,6 +169,7 @@ APPLY_LABEL_COLUMN_QUOTED_IDENTIFIER, APPLY_VALUE_COLUMN_QUOTED_IDENTIFIER, DEFAULT_UDTF_PARTITION_SIZE, + GroupbyApplySortMethod, check_return_variant_and_get_return_type, create_udf_for_series_apply, create_udtf_for_apply_axis_1, @@ -184,11 +182,7 @@ sort_apply_udtf_result_columns_by_pandas_positions, ) from snowflake.snowpark.modin.plugin._internal.binary_op_utils import ( - compute_binary_op_between_scalar_and_snowpark_column, - compute_binary_op_between_snowpark_column_and_scalar, - compute_binary_op_between_snowpark_columns, - compute_binary_op_with_fill_value, - is_binary_op_supported, + BinaryOp, merge_label_and_identifier_pairs, prepare_binop_pairs_between_dataframe_and_dataframe, ) @@ -282,6 +276,8 @@ raise_if_to_datetime_not_supported, timedelta_freq_to_nanos, to_snowflake_timestamp_format, + tz_convert_column, + tz_localize_column, ) from snowflake.snowpark.modin.plugin._internal.transpose_utils import ( clean_up_transpose_result_index_and_labels, @@ -1854,7 +1850,7 @@ def _binary_op_scalar_rhs( replace_mapping = {} data_column_snowpark_pandas_types = [] for identifier in 
self._modin_frame.data_column_snowflake_quoted_identifiers: - expression, snowpark_pandas_type = compute_binary_op_with_fill_value( + expression, snowpark_pandas_type = BinaryOp.create_with_fill_value( op=op, lhs=col(identifier), lhs_datatype=lambda identifier=identifier: self._modin_frame.get_snowflake_type( @@ -1863,7 +1859,7 @@ def _binary_op_scalar_rhs( rhs=pandas_lit(other), rhs_datatype=lambda: infer_object_type(other), fill_value=fill_value, - ) + ).compute() replace_mapping[identifier] = expression data_column_snowpark_pandas_types.append(snowpark_pandas_type) return SnowflakeQueryCompiler( @@ -1914,7 +1910,7 @@ def _binary_op_list_like_rhs_axis_0( replace_mapping = {} snowpark_pandas_types = [] for identifier in new_frame.data_column_snowflake_quoted_identifiers[:-1]: - expression, snowpark_pandas_type = compute_binary_op_with_fill_value( + expression, snowpark_pandas_type = BinaryOp.create_with_fill_value( op=op, lhs=col(identifier), lhs_datatype=lambda identifier=identifier: new_frame.get_snowflake_type( @@ -1923,7 +1919,7 @@ def _binary_op_list_like_rhs_axis_0( rhs=col(other_identifier), rhs_datatype=lambda: new_frame.get_snowflake_type(other_identifier), fill_value=fill_value, - ) + ).compute() replace_mapping[identifier] = expression snowpark_pandas_types.append(snowpark_pandas_type) @@ -1986,7 +1982,7 @@ def _binary_op_list_like_rhs_axis_1( # rhs is not guaranteed to be a scalar value - it can be a list-like as well. # Convert all list-like objects to a list. rhs_lit = pandas_lit(rhs) if is_scalar(rhs) else pandas_lit(rhs.tolist()) - expression, snowpark_pandas_type = compute_binary_op_with_fill_value( + expression, snowpark_pandas_type = BinaryOp.create_with_fill_value( op, lhs=lhs, lhs_datatype=lambda identifier=identifier: self._modin_frame.get_snowflake_type( @@ -1995,7 +1991,7 @@ def _binary_op_list_like_rhs_axis_1( rhs=rhs_lit, rhs_datatype=lambda rhs=rhs: infer_object_type(rhs), fill_value=fill_value, - ) + ).compute() replace_mapping[identifier] = expression snowpark_pandas_types.append(snowpark_pandas_type) @@ -2041,8 +2037,8 @@ def binary_op( # Native pandas does not support binary operations between a Series and a list-like object. from modin.pandas import Series + from modin.pandas.dataframe import DataFrame - from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import is_scalar # fail explicitly for unsupported scenarios @@ -2056,7 +2052,7 @@ def binary_op( # match pandas documentation; hence it is omitted in the Snowpark pandas implementation. raise ValueError("Only scalars can be used as fill_value.") - if not is_binary_op_supported(op): + if not BinaryOp.is_binary_op_supported(op): ErrorMessage.not_implemented( f"Snowpark pandas doesn't yet support '{op}' binary operation" ) @@ -2121,7 +2117,7 @@ def binary_op( )[0] # add new column with result as unnamed - new_column_expr, snowpark_pandas_type = compute_binary_op_with_fill_value( + new_column_expr, snowpark_pandas_type = BinaryOp.create_with_fill_value( op=op, lhs=col(lhs_quoted_identifier), lhs_datatype=lambda: aligned_frame.get_snowflake_type( @@ -2132,7 +2128,7 @@ def binary_op( rhs_quoted_identifier ), fill_value=fill_value, - ) + ).compute() # name is dropped when names of series differ. A dropped name is using unnamed series label. 
new_column_name = ( @@ -3557,42 +3553,22 @@ def convert_func_to_agg_func_info( agg_col_ops, new_data_column_index_names = generate_column_agg_info( internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby ) - # Get the column aggregation functions used to check if the function - # preserves Snowpark pandas types. - agg_col_funcs = [] - for _, func in column_to_agg_func.items(): - if is_list_like(func) and not is_named_tuple(func): - for fn in func: - agg_col_funcs.append(fn.func) - else: - agg_col_funcs.append(func.func) # the pandas label and quoted identifier generated for each result column # after aggregation will be used as new pandas label and quoted identifiers. new_data_column_pandas_labels = [] new_data_column_quoted_identifiers = [] new_data_column_snowpark_pandas_types = [] - for i in range(len(agg_col_ops)): - col_agg_op = agg_col_ops[i] - col_agg_func = agg_col_funcs[i] - new_data_column_pandas_labels.append(col_agg_op.agg_pandas_label) + for agg_col_op in agg_col_ops: + new_data_column_pandas_labels.append(agg_col_op.agg_pandas_label) new_data_column_quoted_identifiers.append( - col_agg_op.agg_snowflake_quoted_identifier + agg_col_op.agg_snowflake_quoted_identifier + ) + new_data_column_snowpark_pandas_types.append( + agg_col_op.data_type + if isinstance(agg_col_op.data_type, SnowparkPandasType) + and agg_col_op.snowflake_agg_func.preserves_snowpark_pandas_types + else None ) - if col_agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE: - new_data_column_snowpark_pandas_types.append( - col_agg_op.data_type - if isinstance(col_agg_op.data_type, SnowparkPandasType) - else None - ) - elif col_agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES: - # In the case where the aggregation overrides the type of the output data column - # (e.g. any always returns boolean data columns), set the output Snowpark pandas type - # of the given column to None - new_data_column_snowpark_pandas_types.append(None) # type: ignore - else: - self._raise_not_implemented_error_for_timedelta() - new_data_column_snowpark_pandas_types = None # type: ignore - # The ordering of the named aggregations is changed by us when we process # the agg_kwargs into the func dict (named aggregations on the same # column are moved to be contiguous, see groupby.py::aggregate for an @@ -3645,7 +3621,7 @@ def convert_func_to_agg_func_info( ), agg_pandas_label=None, agg_snowflake_quoted_identifier=row_position_quoted_identifier, - snowflake_agg_func=min_, + snowflake_agg_func=get_snowflake_agg_func("min", agg_kwargs={}, axis=0), ordering_columns=internal_frame.ordering_columns, ) agg_col_ops.append(row_position_agg_column_op) @@ -3757,6 +3733,8 @@ def groupby_apply( agg_args: Any, agg_kwargs: dict[str, Any], series_groupby: bool, + force_single_group: bool = False, + force_list_like_to_series: bool = False, ) -> "SnowflakeQueryCompiler": """ Group according to `by` and `level`, apply a function to each group, and combine the results. @@ -3777,6 +3755,10 @@ def groupby_apply( Keyword arguments to pass to agg_func when applying it to each group. 
series_groupby: Whether we are performing a SeriesGroupBy.apply() instead of a DataFrameGroupBy.apply() + force_single_group: + Force single group (empty set of group by labels) useful for DataFrame.apply() with axis=0 + force_list_like_to_series: + Force the function result to series if it is list-like Returns ------- @@ -3804,15 +3786,23 @@ def groupby_apply( dropna = groupby_kwargs.get("dropna", True) group_keys = groupby_kwargs.get("group_keys", False) - by_pandas_labels = extract_groupby_column_pandas_labels(self, by, level) + by_pandas_labels = ( + [] + if force_single_group + else extract_groupby_column_pandas_labels(self, by, level) + ) - by_snowflake_quoted_identifiers_list = [ - quoted_identifier - for entry in self._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - by_pandas_labels - ) - for quoted_identifier in entry - ] + by_snowflake_quoted_identifiers_list = ( + [] + if force_single_group + else [ + quoted_identifier + for entry in self._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + by_pandas_labels + ) + for quoted_identifier in entry + ] + ) snowflake_type_map = self._modin_frame.quoted_identifier_to_snowflake_type() @@ -3846,11 +3836,14 @@ def groupby_apply( ], session=self._modin_frame.ordered_dataframe.session, series_groupby=series_groupby, - by_types=[ + by_types=[] + if force_single_group + else [ snowflake_type_map[quoted_identifier] for quoted_identifier in by_snowflake_quoted_identifiers_list ], existing_identifiers=self._modin_frame.ordered_dataframe._dataframe_ref.snowflake_quoted_identifiers, + force_list_like_to_series=force_list_like_to_series, ) new_internal_df = self._modin_frame.ensure_row_position_column() @@ -3922,9 +3915,9 @@ def groupby_apply( *new_internal_df.index_column_snowflake_quoted_identifiers, *input_data_column_identifiers, ).over( - partition_by=[ - *by_snowflake_quoted_identifiers_list, - ], + partition_by=None + if force_single_group + else [*by_snowflake_quoted_identifiers_list], order_by=row_position_snowflake_quoted_identifier, ), ) @@ -4066,7 +4059,9 @@ def groupby_apply( ordered_dataframe=ordered_dataframe, agg_func=agg_func, by_snowflake_quoted_identifiers_list=by_snowflake_quoted_identifiers_list, - sort_method=groupby_apply_sort_method( + sort_method=GroupbyApplySortMethod.ORIGINAL_ROW_ORDER + if force_single_group + else groupby_apply_sort_method( sort, group_keys, original_row_position_snowflake_quoted_identifier, @@ -5639,8 +5634,6 @@ def agg( args: the arguments passed for the aggregation kwargs: keyword arguments passed for the aggregation function. """ - self._raise_not_implemented_error_for_timedelta() - numeric_only = kwargs.get("numeric_only", False) # Call fallback if the aggregation function passed in the arg is currently not supported # by snowflake engine. 
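# --- Editor's sketch, not part of the patch. With the eager
# _raise_not_implemented_error_for_timedelta() guard removed from agg() above,
# column-wise (axis=0) aggregation over Timedelta data is expected to work for the
# supported aggregation functions. Native pandas is used here to show the intended
# semantics; the DataFrame and column names are invented for illustration.
import pandas as pd

df = pd.DataFrame(
    {
        "elapsed": pd.to_timedelta(["1 days", "2 days", "3 days"]),
        "count": [1, 2, 3],
    }
)

# min/max/sum keep the Timedelta type for "elapsed" while "count" stays numeric,
# mirroring how the query compiler now propagates Snowpark pandas types only through
# aggregations whose snowflake_agg_func.preserves_snowpark_pandas_types is true.
print(df.agg(["min", "max", "sum"]))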
@@ -5686,6 +5679,11 @@ def agg( not is_list_like(value) for value in func.values() ) if axis == 1: + if any( + isinstance(t, TimedeltaType) + for t in internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values() + ): + ErrorMessage.not_implemented_for_timedelta("agg(axis=1)") if self.is_multiindex(): # TODO SNOW-1010307 fix axis=1 behavior with MultiIndex ErrorMessage.not_implemented( @@ -5743,9 +5741,9 @@ def agg( pandas_column_labels=frame.data_column_pandas_labels, ) if agg_arg in ("idxmin", "idxmax") - else generate_rowwise_aggregation_function(agg_arg, kwargs)( - *(col(c) for c in data_col_identifiers) - ) + else get_snowflake_agg_func( + agg_arg, kwargs, axis=1 + ).snowpark_aggregation(*(col(c) for c in data_col_identifiers)) for agg_arg in agg_args } pandas_labels = list(agg_col_map.keys()) @@ -5865,7 +5863,13 @@ def generate_agg_qc( index_column_snowflake_quoted_identifiers=[ agg_name_col_quoted_identifier ], - data_column_types=None, + data_column_types=[ + col.data_type + if isinstance(col.data_type, SnowparkPandasType) + and col.snowflake_agg_func.preserves_snowpark_pandas_types + else None + for col in col_agg_infos + ], index_column_types=None, ) return SnowflakeQueryCompiler(single_agg_dataframe) @@ -7377,28 +7381,34 @@ def merge_asof( SnowflakeQueryCompiler """ # TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation - if ( - by - or left_by - or right_by - or left_index - or right_index - or tolerance - or suffixes != ("_x", "_y") - ): + if left_index or right_index or tolerance or suffixes != ("_x", "_y"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method does not currently support parameters " - + "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - + "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ) if direction not in ("backward", "forward"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'" ) + if direction == "backward": + match_comparator = ( + MatchComparator.GREATER_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.GREATER_THAN + ) + else: + match_comparator = ( + MatchComparator.LESS_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.LESS_THAN + ) + left_frame = self._modin_frame right_frame = right._modin_frame - left_keys, right_keys = join_utils.get_join_keys( + # Get the left and right matching key and quoted identifier corresponding to the match_condition + # There will only be matching key/identifier for each table as there is only a single match condition + left_match_keys, right_match_keys = join_utils.get_join_keys( left=left_frame, right=right_frame, on=on, @@ -7407,42 +7417,62 @@ def merge_asof( left_index=left_index, right_index=right_index, ) - left_match_col = ( + left_match_identifier = ( left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - left_keys + left_match_keys )[0][0] ) - right_match_col = ( + right_match_identifier = ( right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - right_keys + right_match_keys )[0][0] ) - - if direction == "backward": - match_comparator = ( - MatchComparator.GREATER_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.GREATER_THAN + coalesce_config = join_utils.get_coalesce_config( + left_keys=left_match_keys, + right_keys=right_match_keys, + external_join_keys=[], + ) + + # Get the left and right matching keys and quoted identifiers corresponding to the 'on' 
condition + if by or (left_by and right_by): + left_on_keys, right_on_keys = join_utils.get_join_keys( + left=left_frame, + right=right_frame, + on=by, + left_on=left_by, + right_on=right_by, + ) + left_on_identifiers = [ + ids[0] + for ids in left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + left_on_keys + ) + ] + right_on_identifiers = [ + ids[0] + for ids in right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + right_on_keys + ) + ] + coalesce_config.extend( + join_utils.get_coalesce_config( + left_keys=left_on_keys, + right_keys=right_on_keys, + external_join_keys=[], + ) ) else: - match_comparator = ( - MatchComparator.LESS_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.LESS_THAN - ) - - coalesce_config = join_utils.get_coalesce_config( - left_keys=left_keys, right_keys=right_keys, external_join_keys=[] - ) + left_on_identifiers = [] + right_on_identifiers = [] joined_frame, _ = join_utils.join( left=left_frame, right=right_frame, + left_on=left_on_identifiers, + right_on=right_on_identifiers, how="asof", - left_on=[left_match_col], - right_on=[right_match_col], - left_match_col=left_match_col, - right_match_col=right_match_col, + left_match_col=left_match_identifier, + right_match_col=right_match_identifier, match_comparator=match_comparator, join_key_coalesce_config=coalesce_config, sort=True, @@ -7888,11 +7918,6 @@ def apply( """ self._raise_not_implemented_error_for_timedelta() - # axis=0 is not supported, raise error. - if axis == 0: - ErrorMessage.not_implemented( - "Snowpark pandas apply API doesn't yet support axis == 0" - ) # Only callables are supported for axis=1 mode for now. if not callable(func) and not isinstance(func, UserDefinedFunction): ErrorMessage.not_implemented( @@ -7909,56 +7934,260 @@ def apply( "Snowpark pandas apply API doesn't yet support DataFrame or Series in 'args' or 'kwargs' of 'func'" ) - # get input types of all data columns from the dataframe directly - input_types = self._modin_frame.get_snowflake_type( - self._modin_frame.data_column_snowflake_quoted_identifiers - ) + if axis == 0: + frame = self._modin_frame - from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native + # To apply function to Dataframe with axis=0, we repurpose the groupby apply function by taking each + # column, as a series, and treat as a single group to apply function. Then collect the column results to + # join together for the final result. + col_results = [] - # current columns - column_index = try_convert_index_to_native(self._modin_frame.data_columns_index) + # If raw, then pass numpy ndarray rather than pandas Series as input to the apply function. + if raw: - # Extract return type from annotations (or lookup for known pandas functions) for func object, - # if not return type could be extracted the variable will hold None. - return_type = deduce_return_type_from_function(func) + def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no cover: adding type hint causes an error when creating udtf. also, skip coverage for this function because coverage tools can't tell that we're executing this function because we execute it in a UDTF. + raw_input_obj = args[0].to_numpy() + args = (raw_input_obj,) + args[1:] + return func(*args, **kwargs) - # Check whether return_type has been extracted. If return type is not - # a Series, tuple or list object, compute df.apply using a vUDF. 
In this case no column expansion needs to - # be performed which means that the result of df.apply(axis=1) is always a Series object. - if return_type and not ( - isinstance(return_type, PandasSeriesType) - or isinstance(return_type, ArrayType) - ): - return self._apply_udf_row_wise_and_reduce_to_series_along_axis_1( - func, - column_index, - input_types, - return_type, - udf_args=args, - udf_kwargs=kwargs, - session=self._modin_frame.ordered_dataframe.session, - ) + agg_func = wrapped_func + else: + agg_func = func + + # Accumulate indices of the column results. + col_result_indexes = [] + # Accumulate "is scalar" flags for the column results. + col_result_scalars = [] + + # Loop through each data column of the original df frame + for (column_index, data_column_pair) in enumerate( + zip( + frame.data_column_pandas_labels, + frame.data_column_snowflake_quoted_identifiers, + ) + ): + ( + data_column_pandas_label, + data_column_snowflake_quoted_identifier, + ) = data_column_pair + + # Create a frame for the current data column which we will be passed to the apply function below. + # Note that we maintain the original index because the apply function may access via the index. + data_col_qc = self.take_2d_positional( + index=slice(None, None), columns=[column_index] + ) + + data_col_frame = data_col_qc._modin_frame + + data_col_qc = data_col_qc.groupby_apply( + by=[], + agg_func=agg_func, + axis=0, + groupby_kwargs={"as_index": False, "dropna": False}, + agg_args=args, + agg_kwargs=kwargs, + series_groupby=True, + force_single_group=True, + force_list_like_to_series=True, + ) + + data_col_result_frame = data_col_qc._modin_frame + + # Set the index names and corresponding data column pandas label on the result. + data_col_result_frame = InternalFrame.create( + ordered_dataframe=data_col_result_frame.ordered_dataframe, + data_column_snowflake_quoted_identifiers=data_col_result_frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_labels=[data_column_pandas_label], + data_column_pandas_index_names=data_col_frame.data_column_pandas_index_names, + data_column_types=None, + index_column_snowflake_quoted_identifiers=data_col_result_frame.index_column_snowflake_quoted_identifiers, + index_column_pandas_labels=data_col_result_frame.index_column_pandas_labels, + index_column_types=data_col_result_frame.cached_index_column_snowpark_pandas_types, + ) + + data_col_result_index = ( + data_col_result_frame.index_columns_pandas_index() + ) + col_result_indexes.append(data_col_result_index) + # TODO: For functions like np.sum, when supported, we can know upfront the result is a scalar + # so don't need to look at the index. + col_result_scalars.append( + len(data_col_result_index) == 1 and data_col_result_index[0] == -1 + ) + col_results.append(SnowflakeQueryCompiler(data_col_result_frame)) + + result_is_series = False + + if len(col_results) == 1: + result_is_series = col_result_scalars[0] + qc_result = col_results[0] + + # Squeeze to series if it is single column + qc_result = qc_result.columnarize() + if col_result_scalars[0]: + qc_result = qc_result.reset_index(drop=True) + else: + single_row_output = all(len(index) == 1 for index in col_result_indexes) + if single_row_output: + all_scalar_output = all( + is_scalar for is_scalar in col_result_scalars + ) + if all_scalar_output: + # If the apply function maps all columns to a scalar value, then we need to join them together + # to return as a Series result. + + # Ensure all column results have the same column name so concat will be aligned. 
+ for i, qc in enumerate(col_results): + col_results[i] = qc.set_columns([0]) + + qc_result = col_results[0].concat( + axis=0, + other=col_results[1:], + keys=frame.data_column_pandas_labels, + ) + qc_frame = qc_result._modin_frame + + # Drop the extraneous index column from the original result series. + qc_result = SnowflakeQueryCompiler( + InternalFrame.create( + ordered_dataframe=qc_frame.ordered_dataframe, + data_column_snowflake_quoted_identifiers=qc_frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_labels=qc_frame.data_column_pandas_labels, + data_column_pandas_index_names=qc_frame.data_column_pandas_index_names, + data_column_types=qc_frame.cached_data_column_snowpark_pandas_types, + index_column_snowflake_quoted_identifiers=qc_frame.index_column_snowflake_quoted_identifiers[ + :-1 + ], + index_column_pandas_labels=qc_frame.index_column_pandas_labels[ + :-1 + ], + index_column_types=qc_frame.cached_index_column_snowpark_pandas_types[ + :-1 + ], + ) + ) + + result_is_series = True + else: + no_scalar_output = all( + not is_scalar for is_scalar in col_result_scalars + ) + if no_scalar_output: + # Output is Dataframe + all_same_index = col_result_indexes.count( + col_result_indexes[0] + ) == len(col_result_indexes) + qc_result = col_results[0].concat( + axis=1, other=col_results[1:], sort=not all_same_index + ) + else: + # If there's a mix of scalar and pd.Series output from the apply func, pandas stores the + # pd.Series output as the value, which we do not currently support. + ErrorMessage.not_implemented( + "Nested pd.Series in result is not supported in DataFrame.apply(axis=0)" + ) + else: + if any(is_scalar for is_scalar in col_result_scalars): + # If there's a mix of scalar and pd.Series output from the apply func, pandas stores the + # pd.Series output as the value, which we do not currently support. + ErrorMessage.not_implemented( + "Nested pd.Series in result is not supported in DataFrame.apply(axis=0)" + ) + + duplicate_index_values = not all( + len(i) == len(set(i)) for i in col_result_indexes + ) + + # If there are duplicate index values then align on the index for matching results with Pandas. + if duplicate_index_values: + curr_frame = col_results[0]._modin_frame + for next_qc in col_results[1:]: + curr_frame = join_utils.align( + curr_frame, next_qc._modin_frame, [], [], how="left" + ).result_frame + qc_result = SnowflakeQueryCompiler(curr_frame) + else: + # If there are multiple output series with different indices, then line them up as a series output. + all_same_index = all( + all(i == col_result_indexes[0]) for i in col_result_indexes + ) + # If the col results all have same index then we keep the existing index ordering. + qc_result = col_results[0].concat( + axis=1, other=col_results[1:], sort=not all_same_index + ) + + # If result should be Series then change the data column label appropriately. 
+ if result_is_series: + qc_result_frame = qc_result._modin_frame + qc_result = SnowflakeQueryCompiler( + InternalFrame.create( + ordered_dataframe=qc_result_frame.ordered_dataframe, + data_column_snowflake_quoted_identifiers=qc_result_frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_labels=[MODIN_UNNAMED_SERIES_LABEL], + data_column_pandas_index_names=qc_result_frame.data_column_pandas_index_names, + data_column_types=qc_result_frame.cached_data_column_snowpark_pandas_types, + index_column_snowflake_quoted_identifiers=qc_result_frame.index_column_snowflake_quoted_identifiers, + index_column_pandas_labels=qc_result_frame.index_column_pandas_labels, + index_column_types=qc_result_frame.cached_index_column_snowpark_pandas_types, + ) + ) + + return qc_result else: - # Issue actionable warning for users to consider annotating UDF with type annotations - # for better performance. - function_name = ( - func.__name__ if isinstance(func, Callable) else str(func) # type: ignore[arg-type] + # get input types of all data columns from the dataframe directly + input_types = self._modin_frame.get_snowflake_type( + self._modin_frame.data_column_snowflake_quoted_identifiers ) - WarningMessage.single_warning( - f"Function {function_name} passed to apply does not have type annotations," - f" or Snowpark pandas could not extract type annotations. Executing apply" - f" in slow code path which may result in decreased performance. " - f"To disable this warning and improve performance, consider annotating" - f" {function_name} with type annotations." + + from snowflake.snowpark.modin.pandas.utils import ( + try_convert_index_to_native, ) - # Result may need to get expanded into multiple columns, or return type of func is not known. - # Process using UDTF together with dynamic pivot for either case. - return self._apply_with_udtf_and_dynamic_pivot_along_axis_1( - func, raw, result_type, args, column_index, input_types, **kwargs + # current columns + column_index = try_convert_index_to_native( + self._modin_frame.data_columns_index ) + # Extract return type from annotations (or lookup for known pandas functions) for func object, + # if not return type could be extracted the variable will hold None. + return_type = deduce_return_type_from_function(func) + + # Check whether return_type has been extracted. If return type is not + # a Series, tuple or list object, compute df.apply using a vUDF. In this case no column expansion needs to + # be performed which means that the result of df.apply(axis=1) is always a Series object. + if return_type and not ( + isinstance(return_type, PandasSeriesType) + or isinstance(return_type, ArrayType) + ): + return self._apply_udf_row_wise_and_reduce_to_series_along_axis_1( + func, + column_index, + input_types, + return_type, + udf_args=args, + udf_kwargs=kwargs, + session=self._modin_frame.ordered_dataframe.session, + ) + else: + # Issue actionable warning for users to consider annotating UDF with type annotations + # for better performance. + function_name = ( + func.__name__ if isinstance(func, Callable) else str(func) # type: ignore[arg-type] + ) + WarningMessage.single_warning( + f"Function {function_name} passed to apply does not have type annotations," + f" or Snowpark pandas could not extract type annotations. Executing apply" + f" in slow code path which may result in decreased performance. " + f"To disable this warning and improve performance, consider annotating" + f" {function_name} with type annotations." 
+ ) + + # Result may need to get expanded into multiple columns, or return type of func is not known. + # Process using UDTF together with dynamic pivot for either case. + return self._apply_with_udtf_and_dynamic_pivot_along_axis_1( + func, raw, result_type, args, column_index, input_types, **kwargs + ) + def applymap( self, func: AggFuncType, @@ -8912,7 +9141,9 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler": SnowflakeQueryCompiler Transposed new QueryCompiler object. """ - self._raise_not_implemented_error_for_timedelta() + if len(set(self._modin_frame.cached_data_column_snowpark_pandas_types)) > 1: + # In this case, transpose may lose types. + self._raise_not_implemented_error_for_timedelta() frame = self._modin_frame @@ -10548,7 +10779,7 @@ def _make_discrete_difference_expression( snowpark_pandas_type=None, ) else: - return compute_binary_op_between_snowpark_columns( + return BinaryOp.create( "sub", col(snowflake_quoted_identifier), lambda: column_datatype, @@ -10560,7 +10791,7 @@ def _make_discrete_difference_expression( ) ), lambda: column_datatype, - ) + ).compute() else: # periods is the number of columns to *go back*. @@ -10609,13 +10840,13 @@ def _make_discrete_difference_expression( col1 = cast(col1, IntegerType()) if isinstance(col2_dtype, BooleanType): col2 = cast(col2, IntegerType()) - return compute_binary_op_between_snowpark_columns( + return BinaryOp.create( "sub", col1, lambda: col1_dtype, col2, lambda: col2_dtype, - ) + ).compute() def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler": """ @@ -12296,8 +12527,6 @@ def _quantiles_single_col( column would allow us to create an accurate row position column, but would require a potentially expensive JOIN operator afterwards to apply the correct index labels. """ - self._raise_not_implemented_error_for_timedelta() - assert len(self._modin_frame.data_column_pandas_labels) == 1 if index is not None: @@ -12362,7 +12591,7 @@ def _quantiles_single_col( ], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=[index_identifier], - data_column_types=None, + data_column_types=original_frame.cached_data_column_snowpark_pandas_types, index_column_types=None, ) # We cannot call astype() directly to convert an index column, so we replicate @@ -13396,6 +13625,16 @@ def _window_agg( } ).frame else: + snowflake_agg_func = get_snowflake_agg_func(agg_func, agg_kwargs, axis=0) + if snowflake_agg_func is None: + # We don't have test coverage for this situation because we + # test individual rolling and expanding methods we've implemented, + # like rolling_sum(), but other rolling methods raise + # NotImplementedError immediately. We also don't support rolling + # agg(), which might take us here. 
+ ErrorMessage.not_implemented( # pragma: no cover + f"Window aggregation does not support the aggregation {repr_aggregate_function(agg_func, agg_kwargs)}" + ) new_frame = frame.update_snowflake_quoted_identifiers_with_expressions( { # If aggregation is count use count on row_position_quoted_identifier @@ -13406,7 +13645,7 @@ def _window_agg( if agg_func == "count" else count(col(quoted_identifier)).over(window_expr) >= min_periods, - get_snowflake_agg_func(agg_func, agg_kwargs)( + snowflake_agg_func.snowpark_aggregation( # Expanding is cumulative so replace NULL with 0 for sum aggregation builtin("zeroifnull")(col(quoted_identifier)) if window_func == WindowFunction.EXPANDING @@ -14213,7 +14452,7 @@ def _binary_op_between_dataframe_and_series_along_axis_0( ) ) - # Lazify type map here for calling compute_binary_op_between_snowpark_columns. + # Lazify type map here for calling binaryOp.compute. def create_lazy_type_functions( identifiers: list[str], ) -> list[DataTypeGetter]: @@ -14243,12 +14482,9 @@ def create_lazy_type_functions( replace_mapping = {} snowpark_pandas_types = [] for left, left_datatype in zip(left_result_data_identifiers, left_datatypes): - ( - expression, - snowpark_pandas_type, - ) = compute_binary_op_between_snowpark_columns( + (expression, snowpark_pandas_type,) = BinaryOp.create( op, col(left), left_datatype, col(right), right_datatype - ) + ).compute() snowpark_pandas_types.append(snowpark_pandas_type) replace_mapping[left] = expression update_result = joined_frame.result_frame.update_snowflake_quoted_identifiers_with_expressions( @@ -14363,8 +14599,6 @@ def idxmax( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmax", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -14389,8 +14623,6 @@ def idxmin( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmin", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -14507,14 +14739,14 @@ def infer_sorted_column_labels( replace_mapping = {} data_column_snowpark_pandas_types = [] for p in left_right_pairs: - result_expression, snowpark_pandas_type = compute_binary_op_with_fill_value( + result_expression, snowpark_pandas_type = BinaryOp.create_with_fill_value( op=op, lhs=p.lhs, lhs_datatype=p.lhs_datatype, rhs=p.rhs, rhs_datatype=p.rhs_datatype, fill_value=fill_value, - ) + ).compute() replace_mapping[p.identifier] = result_expression data_column_snowpark_pandas_types.append(snowpark_pandas_type) # Create restricted frame with only combined / replaced labels. @@ -14781,19 +15013,19 @@ def infer_sorted_column_labels( snowpark_pandas_labels = [] for label, identifier in overlapping_pairs: expression, new_type = ( - compute_binary_op_between_scalar_and_snowpark_column( + BinaryOp.create_with_lhs_scalar( op, series.loc[label], col(identifier), datatype_getters[identifier], - ) + ).compute() if squeeze_self - else compute_binary_op_between_snowpark_column_and_scalar( + else BinaryOp.create_with_rhs_scalar( op, col(identifier), datatype_getters[identifier], series.loc[label], - ) + ).compute() ) snowpark_pandas_labels.append(new_type) replace_mapping[identifier] = expression @@ -16454,34 +16686,59 @@ def dt_tz_localize( tz: Union[str, tzinfo], ambiguous: str = "raise", nonexistent: str = "raise", - ) -> None: + include_index: bool = False, + ) -> "SnowflakeQueryCompiler": """ Localize tz-naive to tz-aware. 
Args: tz : str, pytz.timezone, optional ambiguous : {"raise", "inner", "NaT"} or bool mask, default: "raise" nonexistent : {"raise", "shift_forward", "shift_backward, "NaT"} or pandas.timedelta, default: "raise" + include_index: Whether to include the index columns in the operation. Returns: BaseQueryCompiler New QueryCompiler containing values with localized time zone. """ - ErrorMessage.not_implemented( - "Snowpark pandas doesn't yet support the method 'Series.dt.tz_localize'" + dtype = self.index_dtypes[0] if include_index else self.dtypes[0] + if not include_index: + method_name = "Series.dt.tz_localize" + else: + assert is_datetime64_any_dtype(dtype), "column must be datetime" + method_name = "DatetimeIndex.tz_localize" + + if not isinstance(ambiguous, str) or ambiguous != "raise": + ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) + if not isinstance(nonexistent, str) or nonexistent != "raise": + ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) + + return SnowflakeQueryCompiler( + self._modin_frame.apply_snowpark_function_to_columns( + lambda column: tz_localize_column(column, tz), + include_index, + ) ) - def dt_tz_convert(self, tz: Union[str, tzinfo]) -> None: + def dt_tz_convert( + self, + tz: Union[str, tzinfo], + include_index: bool = False, + ) -> "SnowflakeQueryCompiler": """ Convert time-series data to the specified time zone. Args: tz : str, pytz.timezone + include_index: Whether to include the index columns in the operation. Returns: A new QueryCompiler containing values with converted time zone. """ - ErrorMessage.not_implemented( - "Snowpark pandas doesn't yet support the method 'Series.dt.tz_convert'" + return SnowflakeQueryCompiler( + self._modin_frame.apply_snowpark_function_to_columns( + lambda column: tz_convert_column(column, tz), + include_index, + ) ) def dt_ceil( @@ -16524,9 +16781,9 @@ def dt_ceil( "column must be datetime or timedelta" ) # pragma: no cover - if ambiguous != "raise": + if not isinstance(ambiguous, str) or ambiguous != "raise": ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) - if nonexistent != "raise": + if not isinstance(nonexistent, str) or nonexistent != "raise": ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) if is_datetime64_any_dtype(dtype): @@ -16604,9 +16861,10 @@ def dt_round( raise AssertionError( "column must be datetime or timedelta" ) # pragma: no cover - if ambiguous != "raise": + + if not isinstance(ambiguous, str) or ambiguous != "raise": ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) - if nonexistent != "raise": + if not isinstance(nonexistent, str) or nonexistent != "raise": ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) if is_datetime64_any_dtype(dtype): @@ -16762,9 +17020,10 @@ def dt_floor( raise AssertionError( "column must be datetime or timedelta" ) # pragma: no cover - if ambiguous != "raise": + + if not isinstance(ambiguous, str) or ambiguous != "raise": ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) - if nonexistent != "raise": + if not isinstance(nonexistent, str) or nonexistent != "raise": ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) if is_datetime64_any_dtype(dtype): @@ -17246,9 +17505,11 @@ def equals( ) replace_mapping = { - p.identifier: compute_binary_op_between_snowpark_columns( + p.identifier: BinaryOp.create( "equal_null", p.lhs, p.lhs_datatype, p.rhs, p.rhs_datatype - ).snowpark_column + ) + .compute() + 
.snowpark_column for p in left_right_pairs } @@ -17776,7 +18037,7 @@ def compare( right_identifier = result_column_mapper.right_quoted_identifiers_map[ right_identifier ] - op_result = compute_binary_op_between_snowpark_columns( + op_result = BinaryOp.create( op="equal_null", first_operand=col(left_identifier), first_datatype=functools.partial( @@ -17786,7 +18047,7 @@ def compare( second_datatype=functools.partial( lambda col: result_frame.get_snowflake_type(col), right_identifier ), - ) + ).compute() binary_op_result = binary_op_result.append_column( str(left_pandas_label) + "_comparison_result", op_result.snowpark_column, @@ -17897,19 +18158,23 @@ def compare( right_identifier ] - cols_equal = compute_binary_op_between_snowpark_columns( - op="equal_null", - first_operand=col(left_mappped_identifier), - first_datatype=functools.partial( - lambda col: result_frame.get_snowflake_type(col), - left_mappped_identifier, - ), - second_operand=col(right_mapped_identifier), - second_datatype=functools.partial( - lambda col: result_frame.get_snowflake_type(col), - right_mapped_identifier, - ), - ).snowpark_column + cols_equal = ( + BinaryOp.create( + op="equal_null", + first_operand=col(left_mappped_identifier), + first_datatype=functools.partial( + lambda col: result_frame.get_snowflake_type(col), + left_mappped_identifier, + ), + second_operand=col(right_mapped_identifier), + second_datatype=functools.partial( + lambda col: result_frame.get_snowflake_type(col), + right_mapped_identifier, + ), + ) + .compute() + .snowpark_column + ) # Add a column containing the values from `self`, but replace # matching values with null. diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/base.py b/src/snowflake/snowpark/modin/plugin/docstrings/base.py index af50e0379dd..4044f7b675f 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/base.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/base.py @@ -2832,6 +2832,7 @@ def shift(): """ Implement shared functionality between DataFrame and Series for shift. axis argument is only relevant for Dataframe, and should be 0 for Series. + Args: periods : int | Sequence[int] Number of periods to shift. Can be positive or negative. If an iterable of ints, diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py index 6d79d07ab84..f7e93e6c2df 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py @@ -730,7 +730,7 @@ def apply(): Parameters ---------- func : function - A Python function object to apply to each column or row, or a Python function decorated with @udf. + A Python function object to apply to each column or row. axis : {0 or 'index', 1 or 'columns'}, default 0 Axis along which the function is applied: @@ -738,8 +738,6 @@ def apply(): * 0 or 'index': apply function to each column. * 1 or 'columns': apply function to each row. - Snowpark pandas does not yet support ``axis=0``. - raw : bool, default False Determines if row or column is passed as a Series or ndarray object: @@ -810,8 +808,6 @@ def apply(): 7. When ``func`` uses any first-party modules or third-party packages inside the function, you need to add these dependencies via ``session.add_import()`` and ``session.add_packages()``. - Alternatively. specify third-party packages with the @udf decorator. When using the @udf decorator, - annotations using PandasSeriesType or PandasDataFrameType are not supported. 8. 
The Snowpark pandas module cannot currently be referenced inside the definition of ``func``. If you need to call a general pandas API like ``pd.Timestamp`` inside ``func``, @@ -852,22 +848,6 @@ def apply(): 1 14.50 2 24.25 dtype: float64 - - or annotate the function - with the @udf decorator from Snowpark https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.functions.udf. - - >>> from snowflake.snowpark.functions import udf - >>> from snowflake.snowpark.types import DoubleType - >>> @udf(packages=['statsmodels>0.12'], return_type=DoubleType()) - ... def autocorr(column): - ... import pandas as pd - ... import statsmodels.tsa.stattools - ... return pd.Series(statsmodels.tsa.stattools.pacf_ols(column.values)).mean() - ... - >>> df.apply(autocorr, axis=0) # doctest: +SKIP - A 0.857143 - B 0.428571 - dtype: float64 """ def assign(): @@ -1061,8 +1041,6 @@ def transform(): axis : {0 or 'index', 1 or 'columns'}, default 0 If 0 or 'index': apply function to each column. If 1 or 'columns': apply function to each row. - Snowpark pandas currently only supports axis=1, and does not yet support axis=0. - *args Positional arguments to pass to `func`. @@ -1771,7 +1749,7 @@ def info(): ... 'COL2': ['A', 'B', 'C']}) >>> df.info() # doctest: +NORMALIZE_WHITESPACE - + SnowflakeIndex Data columns (total 2 columns): # Column Non-Null Count Dtype diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series.py b/src/snowflake/snowpark/modin/plugin/docstrings/series.py index 1d351fd67af..9e4ebd4d257 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series.py @@ -3428,7 +3428,7 @@ def unique(): >>> pd.Series([pd.Timestamp('2016-01-01', tz='US/Eastern') ... for _ in range(3)]).unique() - array([Timestamp('2015-12-31 21:00:00-0800', tz='America/Los_Angeles')], + array([Timestamp('2016-01-01 00:00:00-0500', tz='UTC-05:00')], dtype=object) """ diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py index 88c4029a92c..b05d7d76db6 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py @@ -1858,10 +1858,181 @@ def to_pydatetime(): pass def tz_localize(): - pass + """ + Localize tz-naive Datetime Array/Index to tz-aware Datetime Array/Index. + + This method takes a time zone (tz) naive Datetime Array/Index object and makes this time zone aware. It does not move the time to another time zone. + + This method can also be used to do the inverse – to create a time zone unaware object from an aware object. To that end, pass tz=None. + + Parameters + ---------- + tz : str, pytz.timezone, dateutil.tz.tzfile, datetime.tzinfo or None + Time zone to convert timestamps to. Passing None will remove the time zone information preserving local time. + ambiguous : ‘infer’, ‘NaT’, bool array, default ‘raise’ + When clocks moved backward due to DST, ambiguous times may arise. For example in Central European Time (UTC+01), when going from 03:00 DST to 02:00 non-DST, 02:30:00 local time occurs both at 00:30:00 UTC and at 01:30:00 UTC. In such a situation, the ambiguous parameter dictates how ambiguous times should be handled. 
+ - ‘infer’ will attempt to infer fall dst-transition hours based on order + - bool-ndarray where True signifies a DST time, False signifies a non-DST time (note that this flag is only applicable for ambiguous times) + - ‘NaT’ will return NaT where there are ambiguous times + - ‘raise’ will raise an AmbiguousTimeError if there are ambiguous times. + nonexistent : ‘shift_forward’, ‘shift_backward, ‘NaT’, timedelta, default ‘raise’ + A nonexistent time does not exist in a particular timezone where clocks moved forward due to DST. + - ‘shift_forward’ will shift the nonexistent time forward to the closest existing time + - ‘shift_backward’ will shift the nonexistent time backward to the closest existing time + - ‘NaT’ will return NaT where there are nonexistent times + - timedelta objects will shift nonexistent times by the timedelta + - ‘raise’ will raise an NonExistentTimeError if there are nonexistent times. + + Returns + ------- + Same type as self + Array/Index converted to the specified time zone. + + Raises + ------ + TypeError + If the Datetime Array/Index is tz-aware and tz is not None. + + See also + -------- + DatetimeIndex.tz_convert + Convert tz-aware DatetimeIndex from one time zone to another. + + Examples + -------- + >>> tz_naive = pd.date_range('2018-03-01 09:00', periods=3) + >>> tz_naive + DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', + '2018-03-03 09:00:00'], + dtype='datetime64[ns]', freq=None) + + Localize DatetimeIndex in US/Eastern time zone: + + >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP + >>> tz_aware # doctest: +SKIP + DatetimeIndex(['2018-03-01 09:00:00-05:00', + '2018-03-02 09:00:00-05:00', + '2018-03-03 09:00:00-05:00'], + dtype='datetime64[ns, US/Eastern]', freq=None) + + With the tz=None, we can remove the time zone information while keeping the local time (not converted to UTC): + + >>> tz_aware.tz_localize(None) # doctest: +SKIP + DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', + '2018-03-03 09:00:00'], + dtype='datetime64[ns]', freq=None) + + Be careful with DST changes. When there is sequential data, pandas can infer the DST time: + + >>> s = pd.to_datetime(pd.Series(['2018-10-28 01:30:00', + ... '2018-10-28 02:00:00', + ... '2018-10-28 02:30:00', + ... '2018-10-28 02:00:00', + ... '2018-10-28 02:30:00', + ... '2018-10-28 03:00:00', + ... '2018-10-28 03:30:00'])) + >>> s.dt.tz_localize('CET', ambiguous='infer') # doctest: +SKIP + 0 2018-10-28 01:30:00+02:00 + 1 2018-10-28 02:00:00+02:00 + 2 2018-10-28 02:30:00+02:00 + 3 2018-10-28 02:00:00+01:00 + 4 2018-10-28 02:30:00+01:00 + 5 2018-10-28 03:00:00+01:00 + 6 2018-10-28 03:30:00+01:00 + dtype: datetime64[ns, CET] + + In some cases, inferring the DST is impossible. In such cases, you can pass an ndarray to the ambiguous parameter to set the DST explicitly + + >>> s = pd.to_datetime(pd.Series(['2018-10-28 01:20:00', + ... '2018-10-28 02:36:00', + ... '2018-10-28 03:46:00'])) + >>> s.dt.tz_localize('CET', ambiguous=np.array([True, True, False])) # doctest: +SKIP + 0 2018-10-28 01:20:00+02:00 + 1 2018-10-28 02:36:00+02:00 + 2 2018-10-28 03:46:00+01:00 + dtype: datetime64[ns, CET] + + If the DST transition causes nonexistent times, you can shift these dates forward or backwards with a timedelta object or ‘shift_forward’ or ‘shift_backwards’. + + >>> s = pd.to_datetime(pd.Series(['2015-03-29 02:30:00', + ... 
'2015-03-29 03:30:00'])) + >>> s.dt.tz_localize('Europe/Warsaw', nonexistent='shift_forward') # doctest: +SKIP + 0 2015-03-29 03:00:00+02:00 + 1 2015-03-29 03:30:00+02:00 + dtype: datetime64[ns, Europe/Warsaw] + + >>> s.dt.tz_localize('Europe/Warsaw', nonexistent='shift_backward') # doctest: +SKIP + 0 2015-03-29 01:59:59.999999999+01:00 + 1 2015-03-29 03:30:00+02:00 + dtype: datetime64[ns, Europe/Warsaw] + + >>> s.dt.tz_localize('Europe/Warsaw', nonexistent=pd.Timedelta('1h')) # doctest: +SKIP + 0 2015-03-29 03:30:00+02:00 + 1 2015-03-29 03:30:00+02:00 + dtype: datetime64[ns, Europe/Warsaw] + """ def tz_convert(): - pass + """ + Convert tz-aware Datetime Array/Index from one time zone to another. + + Parameters + ---------- + tz : str, pytz.timezone, dateutil.tz.tzfile, datetime.tzinfo or None + Time zone for time. Corresponding timestamps would be converted to this time zone of the Datetime Array/Index. A tz of None will convert to UTC and remove the timezone information. + + Returns + ------- + Array or Index + + Raises + ------ + TypeError + If Datetime Array/Index is tz-naive. + + See also + DatetimeIndex.tz + A timezone that has a variable offset from UTC. + DatetimeIndex.tz_localize + Localize tz-naive DatetimeIndex to a given time zone, or remove timezone from a tz-aware DatetimeIndex. + + Examples + -------- + With the tz parameter, we can change the DatetimeIndex to other time zones: + + >>> dti = pd.date_range(start='2014-08-01 09:00', + ... freq='h', periods=3, tz='Europe/Berlin') # doctest: +SKIP + + >>> dti # doctest: +SKIP + DatetimeIndex(['2014-08-01 09:00:00+02:00', + '2014-08-01 10:00:00+02:00', + '2014-08-01 11:00:00+02:00'], + dtype='datetime64[ns, Europe/Berlin]', freq='h') + + >>> dti.tz_convert('US/Central') # doctest: +SKIP + DatetimeIndex(['2014-08-01 02:00:00-05:00', + '2014-08-01 03:00:00-05:00', + '2014-08-01 04:00:00-05:00'], + dtype='datetime64[ns, US/Central]', freq='h') + + With the tz=None, we can remove the timezone (after converting to UTC if necessary): + + >>> dti = pd.date_range(start='2014-08-01 09:00', freq='h', + ... periods=3, tz='Europe/Berlin') # doctest: +SKIP + + >>> dti # doctest: +SKIP + DatetimeIndex(['2014-08-01 09:00:00+02:00', + '2014-08-01 10:00:00+02:00', + '2014-08-01 11:00:00+02:00'], + dtype='datetime64[ns, Europe/Berlin]', freq='h') + + >>> dti.tz_convert(None) # doctest: +SKIP + DatetimeIndex(['2014-08-01 07:00:00', + '2014-08-01 08:00:00', + '2014-08-01 09:00:00'], + dtype='datetime64[ns]', freq='h') + """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. def normalize(): pass diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index aeca9d6e305..ecef6e843ba 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -60,7 +60,6 @@ validate_percentile, ) -import snowflake.snowpark.modin.pandas as spd from snowflake.snowpark.modin.pandas.api.extensions import ( register_dataframe_accessor, register_series_accessor, @@ -88,8 +87,6 @@ def register_base_override(method_name: str): for directly overriding methods on BasePandasDataset, we mock this by performing the override on DataFrame and Series, and manually performing a `setattr` on the base class. These steps are necessary to allow both the docstring extension and method dispatch to work properly. 
- - Methods annotated here also are automatically instrumented with Snowpark pandas telemetry. """ def decorator(base_method: Any): @@ -103,10 +100,7 @@ def decorator(base_method: Any): series_method = series_method.fget if series_method is None or series_method is parent_method: register_series_accessor(method_name)(base_method) - # TODO: SNOW-1063346 - # Since we still use the vendored version of DataFrame and the overrides for the top-level - # namespace haven't been performed yet, we need to set properties on the vendored version - df_method = getattr(spd.dataframe.DataFrame, method_name, None) + df_method = getattr(pd.DataFrame, method_name, None) if isinstance(df_method, property): df_method = df_method.fget if df_method is None or df_method is parent_method: @@ -176,6 +170,22 @@ def filter( pass # pragma: no cover +@register_base_not_implemented() +def interpolate( + self, + method="linear", + *, + axis=0, + limit=None, + inplace=False, + limit_direction: str | None = None, + limit_area=None, + downcast=lib.no_default, + **kwargs, +): # noqa: PR01, RT01, D200 + pass + + @register_base_not_implemented() def pipe(self, func, *args, **kwargs): # noqa: PR01, RT01, D200 pass # pragma: no cover @@ -813,7 +823,7 @@ def _binary_op( **kwargs, ) - from snowflake.snowpark.modin.pandas.dataframe import DataFrame + from modin.pandas.dataframe import DataFrame # Modin Bug: https://github.com/modin-project/modin/issues/7236 # For a Series interacting with a DataFrame, always return a DataFrame diff --git a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py index 5ce836061ab..62c9cab4dc1 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py @@ -7,20 +7,1443 @@ pandas, such as `DataFrame.memory_usage`. 
""" -from typing import Any, Union +from __future__ import annotations +import collections +import datetime +import functools +import itertools +import sys +import warnings +from typing import ( + IO, + Any, + Callable, + Hashable, + Iterable, + Iterator, + Literal, + Mapping, + Sequence, +) + +import modin.pandas as pd +import numpy as np import pandas as native_pd -from modin.pandas import DataFrame -from pandas._typing import Axis, PythonFuncType -from pandas.core.dtypes.common import is_dict_like, is_list_like +from modin.pandas import DataFrame, Series +from modin.pandas.base import BasePandasDataset +from pandas._libs.lib import NoDefault, no_default +from pandas._typing import ( + AggFuncType, + AnyArrayLike, + Axes, + Axis, + CompressionOptions, + FilePath, + FillnaOptions, + IgnoreRaise, + IndexLabel, + Level, + PythonFuncType, + Renamer, + Scalar, + StorageOptions, + Suffixes, + WriteBuffer, +) +from pandas.core.common import apply_if_callable, is_bool_indexer +from pandas.core.dtypes.common import ( + infer_dtype_from_object, + is_bool_dtype, + is_dict_like, + is_list_like, + is_numeric_dtype, +) +from pandas.core.dtypes.inference import is_hashable, is_integer +from pandas.core.indexes.frozen import FrozenList +from pandas.io.formats.printing import pprint_thing +from pandas.util._validators import validate_bool_kwarg + +from snowflake.snowpark.modin.pandas.api.extensions import register_dataframe_accessor +from snowflake.snowpark.modin.pandas.groupby import ( + DataFrameGroupBy, + validate_groupby_args, +) +from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( + SnowparkPandasRowPartitionIterator, +) +from snowflake.snowpark.modin.pandas.utils import ( + create_empty_native_pandas_frame, + from_non_pandas, + from_pandas, + is_scalar, + raise_if_native_pandas_objects, + replace_external_data_keys_with_empty_pandas_series, + replace_external_data_keys_with_query_compiler, +) +from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + is_snowflake_agg_func, +) +from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated +from snowflake.snowpark.modin.plugin._typing import ListLike +from snowflake.snowpark.modin.plugin.utils.error_message import ( + ErrorMessage, + dataframe_not_implemented, +) +from snowflake.snowpark.modin.plugin.utils.frontend_constants import ( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE, + DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE, + DF_SETITEM_SLICE_AS_SCALAR_VALUE, +) +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import ( + _inherit_docstrings, + hashable, + validate_int_kwarg, +) +from snowflake.snowpark.udf import UserDefinedFunction + + +def register_dataframe_not_implemented(): + def decorator(base_method: Any): + func = dataframe_not_implemented()(base_method) + register_dataframe_accessor(base_method.__name__)(func) + return func + + return decorator + + +# === UNIMPLEMENTED METHODS === +# The following methods are not implemented in Snowpark pandas, and must be overridden on the +# frontend. These methods fall into a few categories: +# 1. Would work in Snowpark pandas, but we have not tested it. +# 2. Would work in Snowpark pandas, but requires more SQL queries than we are comfortable with. +# 3. Requires materialization (usually via a frontend _default_to_pandas call). +# 4. Performs operations on a native pandas Index object that are nontrivial for Snowpark pandas to manage. 
+ + +# Avoid overwriting builtin `map` by accident +@register_dataframe_accessor("map") +@dataframe_not_implemented() +def _map(self, func, na_action: str | None = None, **kwargs) -> DataFrame: + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def boxplot( + self, + column=None, + by=None, + ax=None, + fontsize=None, + rot=0, + grid=True, + figsize=None, + layout=None, + return_type=None, + backend=None, + **kwargs, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def combine( + self, other, func, fill_value=None, overwrite=True +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def corrwith( + self, other, axis=0, drop=False, method="pearson", numeric_only=False +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def cov( + self, min_periods=None, ddof: int | None = 1, numeric_only=False +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def dot(self, other): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def hist( + self, + column=None, + by=None, + grid=True, + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, + ax=None, + sharex=False, + sharey=False, + figsize=None, + layout=None, + bins=10, + **kwds, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def isetitem(self, loc, value): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def prod( + self, + axis=None, + skipna=True, + numeric_only=False, + min_count=0, + **kwargs, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +register_dataframe_accessor("product")(prod) + + +@register_dataframe_not_implemented() +def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def reindex_like( + self, + other, + method=None, + copy: bool | None = None, + limit=None, + tolerance=None, +) -> DataFrame: # pragma: no cover + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_gbq( + self, + destination_table, + project_id=None, + chunksize=None, + reauth=False, + if_exists="fail", + auth_local_webserver=True, + table_schema=None, + location=None, + progress_bar=True, + credentials=None, +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_html( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep="NaN", + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + justify=None, + max_rows=None, + max_cols=None, + show_dimensions=False, + decimal=".", + bold_rows=True, + classes=None, + escape=True, + notebook=False, + border=None, + table_id=None, + render_links=False, + encoding=None, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_parquet( + self, + path=None, + engine="auto", + 
compression="snappy", + index=None, + partition_cols=None, + storage_options: StorageOptions = None, + **kwargs, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_period( + self, freq=None, axis=0, copy=True +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_records( + self, index=True, column_dtypes=None, index_dtypes=None +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_stata( + self, + path: FilePath | WriteBuffer[bytes], + convert_dates: dict[Hashable, str] | None = None, + write_index: bool = True, + byteorder: str | None = None, + time_stamp: datetime.datetime | None = None, + data_label: str | None = None, + variable_labels: dict[Hashable, str] | None = None, + version: int | None = 114, + convert_strl: Sequence[Hashable] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions = None, + *, + value_labels: dict[Hashable, dict[float | int, str]] | None = None, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_xml( + self, + path_or_buffer=None, + index=True, + root_name="data", + row_name="row", + na_rep=None, + attr_cols=None, + elem_cols=None, + namespaces=None, + prefix=None, + encoding="utf-8", + xml_declaration=True, + pretty_print=True, + parser="lxml", + stylesheet=None, + compression="infer", + storage_options=None, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __delitem__(self, key): + pass # pragma: no cover + + +@register_dataframe_accessor("attrs") +@dataframe_not_implemented() +@property +def attrs(self): # noqa: RT01, D200 + pass # pragma: no cover + + +@register_dataframe_accessor("style") +@dataframe_not_implemented() +@property +def style(self): # noqa: RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __reduce__(self): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __divmod__(self, other): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __rdivmod__(self, other): + pass # pragma: no cover + + +# The from_dict and from_records accessors are class methods and cannot be overridden via the +# extensions module, as they need to be foisted onto the namespace directly because they are not +# routed through getattr. To this end, we manually set DataFrame.from_dict to our new method. +@dataframe_not_implemented() +def from_dict( + cls, data, orient="columns", dtype=None, columns=None +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +DataFrame.from_dict = from_dict + + +@dataframe_not_implemented() +def from_records( + cls, + data, + index=None, + exclude=None, + columns=None, + coerce_float=False, + nrows=None, +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +DataFrame.from_records = from_records + + +# === OVERRIDDEN METHODS === +# The below methods have their frontend implementations overridden compared to the version present +# in series.py. This is usually for one of the following reasons: +# 1. The underlying QC interface used differs from that of modin. Notably, this applies to aggregate +# and binary operations; further work is needed to refactor either our implementation or upstream +# modin's implementation. +# 2. Modin performs extra validation queries that perform extra SQL queries. 
Some of these are already +# fixed on main; see https://github.com/modin-project/modin/issues/7340 for details. +# 3. Upstream Modin defaults to pandas for some edge cases. Defaulting to pandas at the query compiler +# layer is acceptable because we can force the method to raise NotImplementedError, but if a method +# defaults at the frontend, Modin raises a warning and performs the operation by coercing the +# dataset to a native pandas object. Removing these is tracked by +# https://github.com/modin-project/modin/issues/7104 + + +# Snowpark pandas overrides the constructor for two reasons: +# 1. To support the Snowpark pandas lazy index object +# 2. To avoid raising "UserWarning: Distributing object. This may take some time." +# when a literal is passed in as data. +@register_dataframe_accessor("__init__") +def __init__( + self, + data=None, + index=None, + columns=None, + dtype=None, + copy=None, + query_compiler=None, +) -> None: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # Siblings are other dataframes that share the same query compiler. We + # use this list to update inplace when there is a shallow copy. + from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native + + self._siblings = [] + + # Engine.subscribe(_update_engine) + if isinstance(data, (DataFrame, Series)): + self._query_compiler = data._query_compiler.copy() + if index is not None and any(i not in data.index for i in index): + ErrorMessage.not_implemented( + "Passing non-existant columns or index values to constructor not" + + " yet implemented." + ) # pragma: no cover + if isinstance(data, Series): + # We set the column name if it is not in the provided Series + if data.name is None: + self.columns = [0] if columns is None else columns + # If the columns provided are not in the named Series, pandas clears + # the DataFrame and sets columns to the columns provided. + elif columns is not None and data.name not in columns: + self._query_compiler = from_pandas( + self.__constructor__(columns=columns) + )._query_compiler + if index is not None: + self._query_compiler = data.loc[index]._query_compiler + elif columns is None and index is None: + data._add_sibling(self) + else: + if columns is not None and any(i not in data.columns for i in columns): + ErrorMessage.not_implemented( + "Passing non-existant columns or index values to constructor not" + + " yet implemented." 
+ ) # pragma: no cover + if index is None: + index = slice(None) + if columns is None: + columns = slice(None) + self._query_compiler = data.loc[index, columns]._query_compiler + + # Check type of data and use appropriate constructor + elif query_compiler is None: + distributed_frame = from_non_pandas(data, index, columns, dtype) + if distributed_frame is not None: + self._query_compiler = distributed_frame._query_compiler + return + + if isinstance(data, native_pd.Index): + pass + elif is_list_like(data) and not is_dict_like(data): + old_dtype = getattr(data, "dtype", None) + values = [ + obj._to_pandas() if isinstance(obj, Series) else obj for obj in data + ] + if isinstance(data, np.ndarray): + data = np.array(values, dtype=old_dtype) + else: + try: + data = type(data)(values, dtype=old_dtype) + except TypeError: + data = values + elif is_dict_like(data) and not isinstance( + data, (native_pd.Series, Series, native_pd.DataFrame, DataFrame) + ): + if columns is not None: + data = {key: value for key, value in data.items() if key in columns} + + if len(data) and all(isinstance(v, Series) for v in data.values()): + from modin.pandas import concat + + new_qc = concat(data.values(), axis=1, keys=data.keys())._query_compiler + + if dtype is not None: + new_qc = new_qc.astype({col: dtype for col in new_qc.columns}) + if index is not None: + new_qc = new_qc.reindex( + axis=0, labels=try_convert_index_to_native(index) + ) + if columns is not None: + new_qc = new_qc.reindex( + axis=1, labels=try_convert_index_to_native(columns) + ) + + self._query_compiler = new_qc + return + + data = { + k: v._to_pandas() if isinstance(v, Series) else v + for k, v in data.items() + } + pandas_df = native_pd.DataFrame( + data=try_convert_index_to_native(data), + index=try_convert_index_to_native(index), + columns=try_convert_index_to_native(columns), + dtype=dtype, + copy=copy, + ) + self._query_compiler = from_pandas(pandas_df)._query_compiler + else: + self._query_compiler = query_compiler + + +@register_dataframe_accessor("__dataframe__") +def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Get a Modin DataFrame that implements the dataframe exchange protocol. + + See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. + + Parameters + ---------- + nan_as_null : bool, default: False + A keyword intended for the consumer to tell the producer + to overwrite null values in the data with ``NaN`` (or ``NaT``). + This currently has no effect; once support for nullable extension + dtypes is added, this value should be propagated to columns. + allow_copy : bool, default: True + A keyword that defines whether or not the library is allowed + to make a copy of the data. For example, copying data would be necessary + if a library supports strided buffers, given that this protocol + specifies contiguous buffers. Currently, if the flag is set to ``False`` + and a copy is needed, a ``RuntimeError`` will be raised. + + Returns + ------- + ProtocolDataframe + A dataframe object following the dataframe protocol specification. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented( + "Snowpark pandas does not support the DataFrame interchange " + + "protocol method `__dataframe__`. To use Snowpark pandas " + + "DataFrames with third-party libraries that try to call the " + + "`__dataframe__` method, please convert this Snowpark pandas " + + "DataFrame to pandas with `to_pandas()`." 
+ ) + + return self._query_compiler.to_dataframe( + nan_as_null=nan_as_null, allow_copy=allow_copy + ) + + +# Snowpark pandas defaults to axis=1 instead of axis=0 for these; we need to investigate if the same should +# apply to upstream Modin. +@register_dataframe_accessor("__and__") +def __and__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__and__", other, axis=1) + + +@register_dataframe_accessor("__rand__") +def __rand__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__rand__", other, axis=1) + + +@register_dataframe_accessor("__or__") +def __or__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__or__", other, axis=1) + + +@register_dataframe_accessor("__ror__") +def __ror__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__ror__", other, axis=1) + + +# Upstream Modin defaults to pandas in some cases. +@register_dataframe_accessor("apply") +def apply( + self, + func: AggFuncType | UserDefinedFunction, + axis: Axis = 0, + raw: bool = False, + result_type: Literal["expand", "reduce", "broadcast"] | None = None, + args=(), + **kwargs, +): + """ + Apply a function along an axis of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) + query_compiler = self._query_compiler.apply( + func, + axis, + raw=raw, + result_type=result_type, + args=args, + **kwargs, + ) + if not isinstance(query_compiler, type(self._query_compiler)): + # A scalar was returned + return query_compiler + + # If True, it is an unamed series. + # Theoretically, if df.apply returns a Series, it will only be an unnamed series + # because the function is supposed to be series -> scalar. + if query_compiler._modin_frame.is_unnamed_series(): + return Series(query_compiler=query_compiler) + else: + return self.__constructor__(query_compiler=query_compiler) + + +# Snowpark pandas uses a separate QC method, while modin directly calls map. +@register_dataframe_accessor("applymap") +def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not callable(func): + raise TypeError(f"{func} is not callable") + return self.__constructor__( + query_compiler=self._query_compiler.applymap( + func, na_action=na_action, **kwargs + ) + ) + + +# We need to override _get_columns to satisfy +# tests/unit/modin/test_type_annotations.py::test_properties_snow_1374293[_get_columns-type_hints1] +# since Modin doesn't provide this type hint. +def _get_columns(self) -> native_pd.Index: + """ + Get the columns for this Snowpark pandas ``DataFrame``. + + Returns + ------- + Index + The all columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.columns + + +# Snowpark pandas wraps this in an update_in_place +def _set_columns(self, new_columns: Axes) -> None: + """ + Set the columns for this Snowpark pandas ``DataFrame``. + + Parameters + ---------- + new_columns : + The new columns to set. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + self._update_inplace( + new_query_compiler=self._query_compiler.set_columns(new_columns) + ) + + +register_dataframe_accessor("columns")(property(_get_columns, _set_columns)) + + +# Snowpark pandas does preprocessing for numeric_only (should be pushed to QC). +@register_dataframe_accessor("corr") +def corr( + self, + method: str | Callable = "pearson", + min_periods: int | None = None, + numeric_only: bool = False, +): # noqa: PR01, RT01, D200 + """ + Compute pairwise correlation of columns, excluding NA/null values. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + corr_df = self + if numeric_only: + corr_df = self.drop( + columns=[ + i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) + ] + ) + return self.__constructor__( + query_compiler=corr_df._query_compiler.corr( + method=method, + min_periods=min_periods, + ) + ) + + +# Snowpark pandas does not respect `ignore_index`, and upstream Modin does not respect `how`. +@register_dataframe_accessor("dropna") +def dropna( + self, + *, + axis: Axis = 0, + how: str | NoDefault = no_default, + thresh: int | NoDefault = no_default, + subset: IndexLabel = None, + inplace: bool = False, +): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super(DataFrame, self)._dropna( + axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace + ) + + +# Snowpark pandas uses `self_is_series`, while upstream Modin uses `squeeze_self` and `squeeze_value`. +@register_dataframe_accessor("fillna") +def fillna( + self, + value: Hashable | Mapping | Series | DataFrame = None, + *, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, +) -> DataFrame | None: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super(DataFrame, self).fillna( + self_is_series=False, + value=value, + method=method, + axis=axis, + inplace=inplace, + limit=limit, + downcast=downcast, + ) + + +# Snowpark pandas does different validation and returns a custom GroupBy object. +@register_dataframe_accessor("groupby") +def groupby( + self, + by=None, + axis: Axis | NoDefault = no_default, + level: IndexLabel | None = None, + as_index: bool = True, + sort: bool = True, + group_keys: bool = True, + observed: bool | NoDefault = no_default, + dropna: bool = True, +): + """ + Group ``DataFrame`` using a mapper or by a ``Series`` of columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if axis is not no_default: + axis = self._get_axis_number(axis) + if axis == 1: + warnings.warn( + "DataFrame.groupby with axis=1 is deprecated. 
Do " + + "`frame.T.groupby(...)` without axis instead.", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + "The 'axis' keyword in DataFrame.groupby is deprecated and " + + "will be removed in a future version.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + validate_groupby_args(by, level, observed) + + axis = self._get_axis_number(axis) + + if axis != 0 and as_index is False: + raise ValueError("as_index=False only valid for axis=0") + + idx_name = None + + if ( + not isinstance(by, Series) + and is_list_like(by) + and len(by) == 1 + # if by is a list-like of (None,), we have to keep it as a list because + # None may be referencing a column or index level whose label is + # `None`, and by=None wold mean that there is no `by` param. + and by[0] is not None + ): + by = by[0] + + if hashable(by) and ( + not callable(by) and not isinstance(by, (native_pd.Grouper, FrozenList)) + ): + idx_name = by + elif isinstance(by, Series): + idx_name = by.name + if by._parent is self: + # if the SnowSeries comes from the current dataframe, + # convert it to labels directly for easy processing + by = by.name + elif is_list_like(by): + if axis == 0 and all( + ( + (hashable(o) and (o in self)) + or isinstance(o, Series) + or (is_list_like(o) and len(o) == len(self.shape[axis])) + ) + for o in by + ): + # plit 'by's into those that belongs to the self (internal_by) + # and those that doesn't (external_by). For SnowSeries that belongs + # to current DataFrame, we convert it to labels for easy process. + internal_by, external_by = [], [] + + for current_by in by: + if hashable(current_by): + internal_by.append(current_by) + elif isinstance(current_by, Series): + if current_by._parent is self: + internal_by.append(current_by.name) + else: + external_by.append(current_by) # pragma: no cover + else: + external_by.append(current_by) + + by = internal_by + external_by + + return DataFrameGroupBy( + self, + by, + axis, + level, + as_index, + sort, + group_keys, + idx_name, + observed=observed, + dropna=dropna, + ) + + +# Upstream Modin uses a proxy DataFrameInfo object +@register_dataframe_accessor("info") +def info( + self, + verbose: bool | None = None, + buf: IO[str] | None = None, + max_cols: int | None = None, + memory_usage: bool | str | None = None, + show_counts: bool | None = None, + null_counts: bool | None = None, +): # noqa: PR01, D200 + """ + Print a concise summary of the ``DataFrame``. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + def put_str(src, output_len=None, spaces=2): + src = str(src) + return src.ljust(output_len if output_len else len(src)) + " " * spaces + + def format_size(num): + for x in ["bytes", "KB", "MB", "GB", "TB"]: + if num < 1024.0: + return f"{num:3.1f} {x}" + num /= 1024.0 + return f"{num:3.1f} PB" + + output = [] + + type_line = str(type(self)) + index_line = "SnowflakeIndex" + columns = self.columns + columns_len = len(columns) + dtypes = self.dtypes + dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}" + + if max_cols is None: + max_cols = 100 + + exceeds_info_cols = columns_len > max_cols + + if buf is None: + buf = sys.stdout + + if null_counts is None: + null_counts = not exceeds_info_cols + + if verbose is None: + verbose = not exceeds_info_cols + + if null_counts and verbose: + # We're gonna take items from `non_null_count` in a loop, which + # works kinda slow with `Modin.Series`, that's why we call `_to_pandas()` here + # that will be faster. + non_null_count = self.count()._to_pandas() + + if memory_usage is None: + memory_usage = True + + def get_header(spaces=2): + output = [] + head_label = " # " + column_label = "Column" + null_label = "Non-Null Count" + dtype_label = "Dtype" + non_null_label = " non-null" + delimiter = "-" + + lengths = {} + lengths["head"] = max(len(head_label), len(pprint_thing(len(columns)))) + lengths["column"] = max( + len(column_label), max(len(pprint_thing(col)) for col in columns) + ) + lengths["dtype"] = len(dtype_label) + dtype_spaces = ( + max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes)) + - lengths["dtype"] + ) + + header = put_str(head_label, lengths["head"]) + put_str( + column_label, lengths["column"] + ) + if null_counts: + lengths["null"] = max( + len(null_label), + max(len(pprint_thing(x)) for x in non_null_count) + len(non_null_label), + ) + header += put_str(null_label, lengths["null"]) + header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces) + + output.append(header) + + delimiters = put_str(delimiter * lengths["head"]) + put_str( + delimiter * lengths["column"] + ) + if null_counts: + delimiters += put_str(delimiter * lengths["null"]) + delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces) + output.append(delimiters) + + return output, lengths + + output.extend([type_line, index_line]) + + def verbose_repr(output): + columns_line = f"Data columns (total {len(columns)} columns):" + header, lengths = get_header() + output.extend([columns_line, *header]) + for i, col in enumerate(columns): + i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]]) + + to_append = put_str(f" {i}", lengths["head"]) + put_str( + col_s, lengths["column"] + ) + if null_counts: + non_null = pprint_thing(non_null_count[col]) + to_append += put_str(f"{non_null} non-null", lengths["null"]) + to_append += put_str(dtype, lengths["dtype"], spaces=0) + output.append(to_append) + + def non_verbose_repr(output): + output.append(columns._summary(name="Columns")) + + if verbose: + verbose_repr(output) + else: + non_verbose_repr(output) + + output.append(dtypes_line) + + if memory_usage: + deep = memory_usage == "deep" + mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum() + mem_line = f"memory usage: {format_size(mem_usage_bytes)}" + + output.append(mem_line) + + output.append("") + buf.write("\n".join(output)) + + +# Snowpark pandas does different 
validation. +@register_dataframe_accessor("insert") +def insert( + self, + loc: int, + column: Hashable, + value: Scalar | AnyArrayLike, + allow_duplicates: bool | NoDefault = no_default, +) -> None: + """ + Insert column into ``DataFrame`` at specified location. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + raise_if_native_pandas_objects(value) + if allow_duplicates is no_default: + allow_duplicates = False + if not allow_duplicates and column in self.columns: + raise ValueError(f"cannot insert {column}, already exists") + + if not isinstance(loc, int): + raise TypeError("loc must be int") + + # If columns labels are multilevel, we implement following behavior (this is + # name native pandas): + # Case 1: if 'column' is tuple it's length must be same as number of levels + # otherwise raise error. + # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in + # empty strings to match the length of column levels in self frame. + if self.columns.nlevels > 1: + if isinstance(column, tuple) and len(column) != self.columns.nlevels: + # same error as native pandas. + raise ValueError("Item must have length equal to number of levels.") + if not isinstance(column, tuple): + # Fill empty strings to match length of levels + suffix = [""] * (self.columns.nlevels - 1) + column = tuple([column] + suffix) + + # Dictionary keys are treated as index column and this should be joined with + # index of target dataframe. This behavior is similar to 'value' being DataFrame + # or Series, so we simply create Series from dict data here. + if isinstance(value, dict): + value = Series(value, name=column) + + if isinstance(value, DataFrame) or ( + isinstance(value, np.ndarray) and len(value.shape) > 1 + ): + # Supported numpy array shapes are + # 1. (N, ) -> Ex. [1, 2, 3] + # 2. (N, 1) -> Ex> [[1], [2], [3]] + if value.shape[1] != 1: + if isinstance(value, DataFrame): + # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin + raise ValueError( + f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead." + ) + else: + raise ValueError( + f"Expected a 1D array, got an array with shape {value.shape}" + ) + # Change numpy array shape from (N, 1) to (N, ) + if isinstance(value, np.ndarray): + value = value.squeeze(axis=1) + + if ( + is_list_like(value) + and not isinstance(value, (Series, DataFrame)) + and len(value) != self.shape[0] + and not 0 == self.shape[0] # dataframe holds no rows + ): + raise ValueError( + "Length of values ({}) does not match length of index ({})".format( + len(value), len(self) + ) + ) + if not -len(self.columns) <= loc <= len(self.columns): + raise IndexError( + f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}" + ) + elif loc < 0: + raise ValueError("unbounded slice") + + join_on_index = False + if isinstance(value, (Series, DataFrame)): + value = value._query_compiler + join_on_index = True + elif is_list_like(value): + value = Series(value, name=column)._query_compiler + + new_query_compiler = self._query_compiler.insert(loc, column, value, join_on_index) + # In pandas, 'insert' operation is always inplace. 
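+    # Illustrative doctest (assumes Snowpark pandas mirrors native pandas semantics here):
+    # `insert` mutates the frame and returns None rather than producing a new DataFrame.
+    #
+    #     >>> df = pd.DataFrame({"a": [1, 2]})
+    #     >>> df.insert(1, "b", [3, 4])  # doctest: +SKIP
+    #     >>> list(df.columns)  # doctest: +SKIP
+    #     ['a', 'b']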
+ self._update_inplace(new_query_compiler=new_query_compiler) + + +# Snowpark pandas does more specialization based on the type of `values` +@register_dataframe_accessor("isin") +def isin( + self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike] +) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(values, dict): + return super(DataFrame, self).isin(values) + elif isinstance(values, Series): + # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. + # if not values.index.is_unique: + # raise ValueError("cannot compute isin with a duplicate axis.") + return self.__constructor__( + query_compiler=self._query_compiler.isin(values._query_compiler) + ) + elif isinstance(values, DataFrame): + # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. + # if not (values.columns.is_unique and values.index.is_unique): + # raise ValueError("cannot compute isin with a duplicate axis.") + return self.__constructor__( + query_compiler=self._query_compiler.isin(values._query_compiler) + ) + else: + if not is_list_like(values): + # throw pandas compatible error + raise TypeError( + "only list-like or dict-like objects are allowed " + f"to be passed to {self.__class__.__name__}.isin(), " + f"you passed a '{type(values).__name__}'" + ) + return super(DataFrame, self).isin(values) + + +# Upstream Modin defaults to pandas for some arguments. +@register_dataframe_accessor("join") +def join( + self, + other: DataFrame | Series | Iterable[DataFrame | Series], + on: IndexLabel | None = None, + how: str = "left", + lsuffix: str = "", + rsuffix: str = "", + sort: bool = False, + validate: str | None = None, +) -> DataFrame: + """ + Join columns of another ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + for o in other if isinstance(other, list) else [other]: + raise_if_native_pandas_objects(o) + + # Similar to native pandas we implement 'join' using 'pd.merge' method. + # Following code is copied from native pandas (with few changes explained below) + # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002 + if isinstance(other, Series): + # Same error as native pandas. + if other.name is None: + raise ValueError("Other Series must have a name") + other = DataFrame(other) + elif is_list_like(other): + if any([isinstance(o, Series) and o.name is None for o in other]): + raise ValueError("Other Series must have a name") + + if isinstance(other, DataFrame): + if how == "cross": + return pd.merge( + self, + other, + how=how, + on=on, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + return pd.merge( + self, + other, + left_on=on, + how=how, + left_index=on is None, + right_index=True, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + else: # List of DataFrame/Series + # Same error as native pandas. + if on is not None: + raise ValueError( + "Joining multiple DataFrames only supported for joining on index" + ) + + # Same error as native pandas. + if rsuffix or lsuffix: + raise ValueError("Suffixes not supported when joining multiple DataFrames") + + # NOTE: These are not the differences between Snowpark pandas API and pandas behavior + # these are differences between native pandas join behavior when join + # frames have unique index or not. + + # In native pandas logic to join multiple DataFrames/Series is data + # dependent. 
Under the hood it will either use 'concat' or 'merge' API + # Case 1. If all objects being joined have unique index use 'concat' (axis=1) + # Case 2. Otherwise use 'merge' API by looping through objects left to right. + # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046 + + # Even though concat (axis=1) and merge are very similar APIs they have + # some differences which leads to inconsistent behavior in native pandas. + # 1. Treatment of un-named Series + # Case #1: Un-named series is allowed in concat API. Objects are joined + # successfully by assigning a number as columns name (see 'concat' API + # documentation for details on treatment of un-named series). + # Case #2: It raises 'ValueError: Other Series must have a name' + + # 2. how='right' + # Case #1: 'concat' API doesn't support right join. It raises + # 'ValueError: Only can inner (intersect) or outer (union) join the other axis' + # Case #2: Merges successfully. + + # 3. Joining frames with duplicate labels but no conflict with other frames + # Example: self = DataFrame(... columns=["A", "B"]) + # other = [DataFrame(... columns=["C", "C"])] + # Case #1: 'ValueError: Indexes have overlapping values' + # Case #2: Merged successfully. -from snowflake.snowpark.modin.pandas.api.extensions import register_dataframe_accessor -from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( - is_snowflake_agg_func, -) -from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage -from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage -from snowflake.snowpark.modin.utils import _inherit_docstrings, validate_int_kwarg + # In addition to this, native pandas implementation also leads to another + # type of inconsistency where left.join(other, ...) and + # left.join([other], ...) might behave differently for cases mentioned + # above. + # Example: + # import pandas as pd + # df = pd.DataFrame({"a": [4, 5]}) + # other = pd.Series([1, 2]) + # df.join([other]) # this is successful + # df.join(other) # this raises 'ValueError: Other Series must have a name' + + # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API + # to join multiple DataFrame/Series. So always follow the behavior + # documented as Case #2 above. + + joined = self + for frame in other: + if isinstance(frame, DataFrame): + overlapping_cols = set(joined.columns).intersection(set(frame.columns)) + if len(overlapping_cols) > 0: + # Native pandas raises: 'Indexes have overlapping values' + # We differ slightly from native pandas message to make it more + # useful to users. + raise ValueError( + f"Join dataframes have overlapping column labels: {overlapping_cols}" + ) + joined = pd.merge( + joined, + frame, + how=how, + left_index=True, + right_index=True, + validate=validate, + sort=sort, + suffixes=(None, None), + ) + return joined + + +# Snowpark pandas does extra error checking. 
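+# A doctest-style sketch of the extra check implemented below: when `other` is a Series
+# and no axis is given, the call is rejected up front instead of being passed to the
+# query compiler.
+#
+#     >>> df = pd.DataFrame({"a": [1, 2]})
+#     >>> df.mask(df["a"] > 1, other=pd.Series([0, 0]))  # doctest: +SKIP
+#     Traceback (most recent call last):
+#     ...
+#     ValueError: df.mask requires an axis parameter (0 or 1) when given a Series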
+@register_dataframe_accessor("mask") +def mask( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(other, Series) and axis is None: + raise ValueError( + "df.mask requires an axis parameter (0 or 1) when given a Series" + ) + + return super(DataFrame, self).mask( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + +# Snowpark pandas has a fix for a pandas behavior change. It is available in Modin 0.30.1 (SNOW-1552497). +@register_dataframe_accessor("melt") +def melt( + self, + id_vars=None, + value_vars=None, + var_name=None, + value_name="value", + col_level=None, + ignore_index=True, +): # noqa: PR01, RT01, D200 + """ + Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if id_vars is None: + id_vars = [] + if not is_list_like(id_vars): + id_vars = [id_vars] + if value_vars is None: + # Behavior of Index.difference changed in 2.2.x + # https://github.com/pandas-dev/pandas/pull/55113 + # This change needs upstream to Modin: + # https://github.com/modin-project/modin/issues/7206 + value_vars = self.columns.drop(id_vars) + if var_name is None: + columns_name = self._query_compiler.get_index_name(axis=1) + var_name = columns_name if columns_name is not None else "variable" + return self.__constructor__( + query_compiler=self._query_compiler.melt( + id_vars=id_vars, + value_vars=value_vars, + var_name=var_name, + value_name=value_name, + col_level=col_level, + ignore_index=ignore_index, + ) + ) + + +# Snowpark pandas does more thorough error checking. +@register_dataframe_accessor("merge") +def merge( + self, + right: DataFrame | Series, + how: str = "inner", + on: IndexLabel | None = None, + left_on: Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike] | None = None, + right_on: Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike] | None = None, + left_index: bool = False, + right_index: bool = False, + sort: bool = False, + suffixes: Suffixes = ("_x", "_y"), + copy: bool = True, + indicator: bool = False, + validate: str | None = None, +) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # Raise error if native pandas objects are passed. + raise_if_native_pandas_objects(right) + + if isinstance(right, Series) and right.name is None: + raise ValueError("Cannot merge a Series without a name") + if not isinstance(right, (Series, DataFrame)): + raise TypeError( + f"Can only merge Series or DataFrame objects, a {type(right)} was passed" + ) + + if isinstance(right, Series): + right_column_nlevels = len(right.name) if isinstance(right.name, tuple) else 1 + else: + right_column_nlevels = right.columns.nlevels + if self.columns.nlevels != right_column_nlevels: + # This is deprecated in native pandas. We raise explicit error for this. + raise ValueError( + "Can not merge objects with different column levels." + + f" ({self.columns.nlevels} levels on the left," + + f" {right_column_nlevels} on the right)" + ) + + # Merge empty native pandas dataframes for error checking. Otherwise, it will + # require a lot of logic to be written. This takes care of raising errors for + # following scenarios: + # 1. Only 'left_index' is set to True. + # 2. 
Only 'right_index is set to True. + # 3. Only 'left_on' is provided. + # 4. Only 'right_on' is provided. + # 5. 'on' and 'left_on' both are provided + # 6. 'on' and 'right_on' both are provided + # 7. 'on' and 'left_index' both are provided + # 8. 'on' and 'right_index' both are provided + # 9. 'left_on' and 'left_index' both are provided + # 10. 'right_on' and 'right_index' both are provided + # 11. Length mismatch between 'left_on' and 'right_on' + # 12. 'left_index' is not a bool + # 13. 'right_index' is not a bool + # 14. 'on' is not None and how='cross' + # 15. 'left_on' is not None and how='cross' + # 16. 'right_on' is not None and how='cross' + # 17. 'left_index' is True and how='cross' + # 18. 'right_index' is True and how='cross' + # 19. Unknown label in 'on', 'left_on' or 'right_on' + # 20. Provided 'suffixes' is not sufficient to resolve conflicts. + # 21. Merging on column with duplicate labels. + # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'} + # 23. conflict with existing labels for array-like join key + # 24. 'indicator' argument is not bool or str + # 25. indicator column label conflicts with existing data labels + create_empty_native_pandas_frame(self).merge( + create_empty_native_pandas_frame(right), + on=on, + how=how, + left_on=replace_external_data_keys_with_empty_pandas_series(left_on), + right_on=replace_external_data_keys_with_empty_pandas_series(right_on), + left_index=left_index, + right_index=right_index, + suffixes=suffixes, + indicator=indicator, + ) + + return self.__constructor__( + query_compiler=self._query_compiler.merge( + right._query_compiler, + how=how, + on=on, + left_on=replace_external_data_keys_with_query_compiler(self, left_on), + right_on=replace_external_data_keys_with_query_compiler(right, right_on), + left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + copy=copy, + indicator=indicator, + validate=validate, + ) + ) @_inherit_docstrings(native_pd.DataFrame.memory_usage, apilink="pandas.DataFrame") @@ -62,6 +1485,125 @@ def memory_usage(self, index: bool = True, deep: bool = False) -> Any: return native_pd.Series([0] * len(columns), index=columns) +# Snowpark pandas handles `inplace` differently. +@register_dataframe_accessor("replace") +def replace( + self, + to_replace=None, + value=no_default, + inplace: bool = False, + limit=None, + regex: bool = False, + method: str | NoDefault = no_default, +): + """ + Replace values given in `to_replace` with `value`. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + new_query_compiler = self._query_compiler.replace( + to_replace=to_replace, + value=value, + limit=limit, + regex=regex, + method=method, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + +# Snowpark pandas interacts with the inplace flag differently. +@register_dataframe_accessor("rename") +def rename( + self, + mapper: Renamer | None = None, + *, + index: Renamer | None = None, + columns: Renamer | None = None, + axis: Axis | None = None, + copy: bool | None = None, + inplace: bool = False, + level: Level | None = None, + errors: IgnoreRaise = "ignore", +) -> DataFrame | None: + """ + Alter axes labels. 
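+
+    Illustrative doctest (an assumption that Snowpark pandas mirrors native pandas here):
+
+    >>> df = pd.DataFrame({"a": [1, 2]})
+    >>> df.rename(columns={"a": "b"}).columns.tolist()  # doctest: +SKIP
+    ['b']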
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + if mapper is None and index is None and columns is None: + raise TypeError("must pass an index to rename") + + if index is not None or columns is not None: + if axis is not None: + raise TypeError( + "Cannot specify both 'axis' and any of 'index' or 'columns'" + ) + elif mapper is not None: + raise TypeError( + "Cannot specify both 'mapper' and any of 'index' or 'columns'" + ) + else: + # use the mapper argument + if axis and self._get_axis_number(axis) == 1: + columns = mapper + else: + index = mapper + + if copy is not None: + WarningMessage.ignored_argument( + operation="dataframe.rename", + argument="copy", + message="copy parameter has been ignored with Snowflake execution engine", + ) + + if isinstance(index, dict): + index = Series(index) + + new_qc = self._query_compiler.rename( + index_renamer=index, columns_renamer=columns, level=level, errors=errors + ) + return self._create_or_update_from_compiler( + new_query_compiler=new_qc, inplace=inplace + ) + + +# Upstream modin converts aggfunc to a cython function if it's a string. +@register_dataframe_accessor("pivot_table") +def pivot_table( + self, + values=None, + index=None, + columns=None, + aggfunc="mean", + fill_value=None, + margins=False, + dropna=True, + margins_name="All", + observed=False, + sort=True, +): + """ + Create a spreadsheet-style pivot table as a ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + result = self.__constructor__( + query_compiler=self._query_compiler.pivot_table( + index=index, + values=values, + columns=columns, + aggfunc=aggfunc, + fill_value=fill_value, + margins=margins, + dropna=dropna, + margins_name=margins_name, + observed=observed, + sort=sort, + ) + ) + return result + + +# Snowpark pandas produces a different warning for materialization. @register_dataframe_accessor("plot") @property def plot( @@ -108,11 +1650,227 @@ def plot( return self._to_pandas().plot +# Upstream Modin defaults when other is a Series. +@register_dataframe_accessor("pow") +def pow( + self, other, axis="columns", level=None, fill_value=None +): # noqa: PR01, RT01, D200 + """ + Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "pow", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + +@register_dataframe_accessor("rpow") +def rpow( + self, other, axis="columns", level=None, fill_value=None +): # noqa: PR01, RT01, D200 + """ + Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "rpow", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + +# Snowpark pandas does extra argument validation, and uses iloc instead of drop at the end. +@register_dataframe_accessor("select_dtypes") +def select_dtypes( + self, + include: ListLike | str | type | None = None, + exclude: ListLike | str | type | None = None, +) -> DataFrame: + """ + Return a subset of the ``DataFrame``'s columns based on the column dtypes. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # This line defers argument validation to pandas, which will raise errors on our behalf in cases + # like if `include` and `exclude` are None, the same type is specified in both lists, or a string + # dtype (as opposed to object) is specified. + native_pd.DataFrame().select_dtypes(include, exclude) + + if include and not is_list_like(include): + include = [include] + elif include is None: + include = [] + if exclude and not is_list_like(exclude): + exclude = [exclude] + elif exclude is None: + exclude = [] + + sel = tuple(map(set, (include, exclude))) + + # The width of the np.int_/float_ alias differs between Windows and other platforms, so + # we need to include a workaround. + # https://github.com/numpy/numpy/issues/9464 + # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036 + def check_sized_number_infer_dtypes(dtype): + if (isinstance(dtype, str) and dtype == "int") or (dtype is int): + return [np.int32, np.int64] + elif dtype == "float" or dtype is float: + return [np.float64, np.float32] + else: + return [infer_dtype_from_object(dtype)] + + include, exclude = map( + lambda x: set( + itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x)) + ), + sel, + ) + # We need to index on column position rather than label in case of duplicates + include_these = native_pd.Series(not bool(include), index=range(len(self.columns))) + exclude_these = native_pd.Series(not bool(exclude), index=range(len(self.columns))) + + def is_dtype_instance_mapper(dtype): + return functools.partial(issubclass, dtype.type) + + for i, dtype in enumerate(self.dtypes): + if include: + include_these[i] = any(map(is_dtype_instance_mapper(dtype), include)) + if exclude: + exclude_these[i] = not any(map(is_dtype_instance_mapper(dtype), exclude)) + + dtype_indexer = include_these & exclude_these + indicate = [i for i, should_keep in dtype_indexer.items() if should_keep] + # We need to use iloc instead of drop in case of duplicate column names + return self.iloc[:, indicate] + + +# Snowpark pandas does extra validation on the `axis` argument. +@register_dataframe_accessor("set_axis") +def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, # ignored +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not is_scalar(axis): + raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") + return super(DataFrame, self).set_axis( + labels=labels, + # 'columns', 'rows, 'index, 0, and 1 are the only valid axis values for df. + axis=native_pd.DataFrame._get_axis_name(axis), + copy=copy, + ) + + +# Snowpark pandas needs extra logic for the lazy index class. +@register_dataframe_accessor("set_index") +def set_index( + self, + keys: IndexLabel + | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable], + drop: bool = True, + append: bool = False, + inplace: bool = False, + verify_integrity: bool = False, +) -> None | DataFrame: + """ + Set the ``DataFrame`` index using existing columns. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + if not isinstance(keys, list): + keys = [keys] + + # make sure key is either hashable, index, or series + label_or_series = [] + + missing = [] + columns = self.columns.tolist() + for key in keys: + raise_if_native_pandas_objects(key) + if isinstance(key, pd.Series): + label_or_series.append(key._query_compiler) + elif isinstance(key, (np.ndarray, list, Iterator)): + label_or_series.append(pd.Series(key)._query_compiler) + elif isinstance(key, (pd.Index, native_pd.MultiIndex)): + label_or_series += [s._query_compiler for s in self._to_series_list(key)] + else: + if not is_hashable(key): + raise TypeError( + f'The parameter "keys" may be a column key, one-dimensional array, or a list ' + f"containing only valid column keys and one-dimensional arrays. Received column " + f"of type {type(key)}" + ) + label_or_series.append(key) + found = key in columns + if columns.count(key) > 1: + raise ValueError(f"The column label '{key}' is not unique") + elif not found: + missing.append(key) + + if missing: + raise KeyError(f"None of {missing} are in the columns") + + new_query_compiler = self._query_compiler.set_index( + label_or_series, drop=drop, append=append + ) + + # TODO: SNOW-782633 improve this code once duplicate is supported + # this needs to pull all index which is inefficient + if verify_integrity and not new_query_compiler.index.is_unique: + duplicates = new_query_compiler.index[ + new_query_compiler.index.to_pandas().duplicated() + ].unique() + raise ValueError(f"Index has duplicate keys: {duplicates}") + + return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace) + + +# Upstream Modin uses `len(self.index)` instead of `len(self)`, which gives an extra query. +@register_dataframe_accessor("shape") +@property +def shape(self) -> tuple[int, int]: + """ + Return a tuple representing the dimensionality of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return len(self), len(self.columns) + + +# Snowpark pands has rewrites to minimize queries from length checks. +@register_dataframe_accessor("squeeze") +def squeeze(self, axis: Axis | None = None): + """ + Squeeze 1 dimensional axis objects into scalars. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) if axis is not None else None + len_columns = self._query_compiler.get_axis_len(1) + if axis == 1 and len_columns == 1: + return Series(query_compiler=self._query_compiler) + if axis in [0, None]: + # get_axis_len(0) results in a sql query to count number of rows in current + # dataframe. We should only compute len_index if axis is 0 or None. + len_index = len(self) + if axis is None and (len_columns == 1 or len_index == 1): + return Series(query_compiler=self._query_compiler).squeeze() + if axis == 0 and len_index == 1: + return Series(query_compiler=self.T._query_compiler) + return self.copy() + + # Upstream modin defines sum differently for series/DF, but we use the same implementation for both. @register_dataframe_accessor("sum") def sum( self, - axis: Union[Axis, None] = None, + axis: Axis | None = None, skipna: bool = True, numeric_only: bool = False, min_count: int = 0, @@ -130,6 +1888,70 @@ def sum( ) +# Snowpark pandas raises a warning where modin defaults to pandas. 
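As that comment notes, the `stack` override below warns when the `future_stack` option is passed (Snowflake execution ignores it) rather than falling back to native pandas; a minimal sketch of a call that triggers the warning (assuming an active Snowpark pandas session, with illustrative data):

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401

    df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    stacked = df.stack(future_stack=True)  # emits an ignored-argument warning, then stacks as usual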
+@register_dataframe_accessor("stack") +def stack( + self, + level: int | str | list = -1, + dropna: bool | NoDefault = no_default, + sort: bool | NoDefault = no_default, + future_stack: bool = False, # ignored +): + """ + Stack the prescribed level(s) from columns to index. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if future_stack is not False: + WarningMessage.ignored_argument( # pragma: no cover + operation="DataFrame.stack", + argument="future_stack", + message="future_stack parameter has been ignored with Snowflake execution engine", + ) + if dropna is NoDefault: + dropna = True # pragma: no cover + if sort is NoDefault: + sort = True # pragma: no cover + + # This ensures that non-pandas MultiIndex objects are caught. + is_multiindex = len(self.columns.names) > 1 + if not is_multiindex or ( + is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels + ): + return self._reduce_dimension( + query_compiler=self._query_compiler.stack(level, dropna, sort) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.stack(level, dropna, sort) + ) + + +# Upstream modin doesn't pass `copy`, so we can't raise a warning for it. +# No need to override the `T` property since that can't take any extra arguments. +@register_dataframe_accessor("transpose") +def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200 + """ + Transpose index and columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if copy: + WarningMessage.ignored_argument( + operation="transpose", + argument="copy", + message="Transpose ignore copy argument in Snowpark pandas API", + ) + + if args: + WarningMessage.ignored_argument( + operation="transpose", + argument="args", + message="Transpose ignores args in Snowpark pandas API", + ) + + return self.__constructor__(query_compiler=self._query_compiler.transpose()) + + +# Upstream modin implements transform in base.py, but we don't yet support Series.transform. @register_dataframe_accessor("transform") def transform( self, func: PythonFuncType, axis: Axis = 0, *args: Any, **kwargs: Any @@ -151,3 +1973,380 @@ def transform( raise ValueError("Function did not transform") return self.apply(func, axis, False, args=args, **kwargs) + + +# Upstream modin defaults to pandas for some arguments. +@register_dataframe_accessor("unstack") +def unstack( + self, + level: int | str | list = -1, + fill_value: int | str | dict = None, + sort: bool = True, +): + """ + Pivot a level of the (necessarily hierarchical) index labels. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # This ensures that non-pandas MultiIndex objects are caught. + nlevels = self._query_compiler.nlevels() + is_multiindex = nlevels > 1 + + if not is_multiindex or ( + is_multiindex and is_list_like(level) and len(level) == nlevels + ): + return self._reduce_dimension( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=False + ) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=False + ) + ) + + +# Upstream modin does different validation and sorting. 
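The `value_counts` override below also keeps the pandas 2.x result naming, returning a Series named "count", or "proportion" when `normalize=True`; a native pandas sketch of that naming convention:

    import pandas as native_pd

    df = native_pd.DataFrame({"a": ["x", "x", "y"]})
    df.value_counts().name                 # 'count'
    df.value_counts(normalize=True).name   # 'proportion'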
+@register_dataframe_accessor("value_counts") +def value_counts( + self, + subset: Sequence[Hashable] | None = None, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + dropna: bool = True, +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return Series( + query_compiler=self._query_compiler.value_counts( + subset=subset, + normalize=normalize, + sort=sort, + ascending=ascending, + dropna=dropna, + ), + name="proportion" if normalize else "count", + ) + + +@register_dataframe_accessor("where") +def where( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + """ + Replace values where the condition is False. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(other, Series) and axis is None: + raise ValueError( + "df.where requires an axis parameter (0 or 1) when given a Series" + ) + + return super(DataFrame, self).where( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + +# Snowpark pandas has a custom iterator. +@register_dataframe_accessor("iterrows") +def iterrows(self) -> Iterator[tuple[Hashable, Series]]: + """ + Iterate over ``DataFrame`` rows as (index, ``Series``) pairs. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + def iterrow_builder(s): + """Return tuple of the given `s` parameter name and the parameter themselves.""" + return s.name, s + + # Raise warning message since iterrows is very inefficient. + WarningMessage.single_warning( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows") + ) + + partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder) + yield from partition_iterator + + +# Snowpark pandas has a custom iterator. +@register_dataframe_accessor("itertuples") +def itertuples( + self, index: bool = True, name: str | None = "Pandas" +) -> Iterable[tuple[Any, ...]]: + """ + Iterate over ``DataFrame`` rows as ``namedtuple``-s. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + + def itertuples_builder(s): + """Return the next namedtuple.""" + # s is the Series of values in the current row. + fields = [] # column names + data = [] # values under each column + + if index: + data.append(s.name) + fields.append("Index") + + # Fill column names and values. + fields.extend(list(self.columns)) + data.extend(s) + + if name is not None: + # Creating the namedtuple. + itertuple = collections.namedtuple(name, fields, rename=True) + return itertuple._make(data) + + # When the name is None, return a regular tuple. + return tuple(data) + + # Raise warning message since itertuples is very inefficient. + WarningMessage.single_warning( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples") + ) + return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True) + + +# Snowpark pandas truncates the repr output. +@register_dataframe_accessor("__repr__") +def __repr__(self): + """ + Return a string representation for a particular ``DataFrame``. 
+ + Returns + ------- + str + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + num_rows = native_pd.get_option("display.max_rows") or len(self) + # see _repr_html_ for comment, allow here also all column behavior + num_cols = native_pd.get_option("display.max_columns") or len(self.columns) + + ( + row_count, + col_count, + repr_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x") + result = repr(repr_df) + + # if truncated, add shape information + if is_repr_truncated(row_count, col_count, num_rows, num_cols): + # The split here is so that we don't repr pandas row lengths. + return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format( + row_count, col_count + ) + else: + return result + + +# Snowpark pandas uses a different default `num_rows` value. +@register_dataframe_accessor("_repr_html_") +def _repr_html_(self): # pragma: no cover + """ + Return a html representation for a particular ``DataFrame``. + + Returns + ------- + str + + Notes + ----- + Supports pandas `display.max_rows` and `display.max_columns` options. + """ + num_rows = native_pd.get_option("display.max_rows") or 60 + # Modin uses here 20 as default, but this does not coincide well with pandas option. Therefore allow + # here value=0 which means display all columns. + num_cols = native_pd.get_option("display.max_columns") + + ( + row_count, + col_count, + repr_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols) + result = repr_df._repr_html_() + + if is_repr_truncated(row_count, col_count, num_rows, num_cols): + # We split so that we insert our correct dataframe dimensions. + return ( + result.split("
<div>")[0]
+            + f"<p>{row_count} rows × {col_count} columns</p>
\n" + ) + else: + return result + + +# Upstream modin just uses `to_datetime` rather than `dataframe_to_datetime` on the query compiler. +@register_dataframe_accessor("_to_datetime") +def _to_datetime(self, **kwargs): + """ + Convert `self` to datetime. + + Parameters + ---------- + **kwargs : dict + Optional arguments to use during query compiler's + `to_datetime` invocation. + + Returns + ------- + Series of datetime64 dtype + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._reduce_dimension( + query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) + ) + + +# Snowpark pandas has the extra `statement_params` argument. +@register_dataframe_accessor("_to_pandas") +def _to_pandas( + self, + *, + statement_params: dict[str, str] | None = None, + **kwargs: Any, +) -> native_pd.DataFrame: + """ + Convert Snowpark pandas DataFrame to pandas DataFrame + + Args: + statement_params: Dictionary of statement level parameters to be set while executing this action. + + Returns: + pandas DataFrame + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.to_pandas(statement_params=statement_params, **kwargs) + + +# Snowpark pandas does more validation and error checking than upstream Modin, and uses different +# helper methods for dispatch. +@register_dataframe_accessor("__setitem__") +def __setitem__(self, key: Any, value: Any): + """ + Set attribute `value` identified by `key`. + + Args: + key: Key to set + value: Value to set + + Note: + In the case where value is any list like or array, pandas checks the array length against the number of rows + of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw + a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if + the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use + enlargement filling with the last value in the array. + + Returns: + None + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + key = apply_if_callable(key, self) + if isinstance(key, DataFrame) or ( + isinstance(key, np.ndarray) and len(key.shape) == 2 + ): + # This case uses mask's codepath to perform the set, but + # we need to duplicate the code here since we are passing + # an additional kwarg `cond_fillna_with_true` to the QC here. + # We need this additional kwarg, since if df.shape + # and key.shape do not align (i.e. df has more rows), + # mask's codepath would mask the additional rows in df + # while for setitem, we need to keep the original values. + if not isinstance(key, DataFrame): + if key.dtype != bool: + raise TypeError( + "Must pass DataFrame or 2-d ndarray with boolean values only" + ) + key = DataFrame(key) + key._query_compiler._shape_hint = "array" + + if value is not None: + value = apply_if_callable(value, self) + + if isinstance(value, np.ndarray): + value = DataFrame(value) + value._query_compiler._shape_hint = "array" + elif isinstance(value, pd.Series): + # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this + # error instead, since it is more descriptive. + raise ValueError( + "setitem with a 2D key does not support Series values." 
+ ) + + if isinstance(value, BasePandasDataset): + value = value._query_compiler + + query_compiler = self._query_compiler.mask( + cond=key._query_compiler, + other=value, + axis=None, + level=None, + cond_fillna_with_true=True, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace=True) + + # Error Checking: + if (isinstance(key, pd.Series) or is_list_like(key)) and (isinstance(value, range)): + raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) + elif isinstance(value, slice): + # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. + raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) + + # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column + # key. + index, columns = slice(None), key + index_is_bool_indexer = False + if isinstance(key, slice): + if is_integer(key.start) and is_integer(key.stop): + # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as + # df.iloc[1:2, :] = val + self.iloc[key] = value + return + index, columns = key, slice(None) + elif isinstance(key, pd.Series): + if is_bool_dtype(key.dtype): + index, columns = key, slice(None) + index_is_bool_indexer = True + elif is_bool_indexer(key): + index, columns = pd.Series(key), slice(None) + index_is_bool_indexer = True + + # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case + # we have to explicitly set matching_item_columns_by_label to False for setitem. + index = index._query_compiler if isinstance(index, BasePandasDataset) else index + columns = ( + columns._query_compiler if isinstance(columns, BasePandasDataset) else columns + ) + from snowflake.snowpark.modin.pandas.indexing import is_2d_array + + matching_item_rows_by_label = not is_2d_array(value) + if is_2d_array(value): + value = DataFrame(value) + item = value._query_compiler if isinstance(value, BasePandasDataset) else value + new_qc = self._query_compiler.set_2d_labels( + index, + columns, + item, + # setitem always matches item by position + matching_item_columns_by_label=False, + matching_item_rows_by_label=matching_item_rows_by_label, + index_is_bool_indexer=index_is_bool_indexer, + # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling + # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the + # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have + # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns + # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", + # "X", "X". 
+ deduplicate_columns=True, + ) + return self._update_inplace(new_query_compiler=new_qc) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index 7be7adb54c1..38edb9f7bee 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -229,7 +229,7 @@ def __init__( -------- >>> idx = pd.DatetimeIndex(["1/1/2020 10:00:00+00:00", "2/1/2020 11:00:00+00:00"], tz="America/Los_Angeles") >>> idx - DatetimeIndex(['2020-01-01 02:00:00-08:00', '2020-02-01 03:00:00-08:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2020-01-01 02:00:00-08:00', '2020-02-01 03:00:00-08:00'], dtype='datetime64[ns, UTC-08:00]', freq=None) """ # DatetimeIndex is already initialized in __new__ method. We keep this method # only for docstring generation. @@ -960,7 +960,6 @@ def snap(self, freq: Frequency = "S") -> DatetimeIndex: DatetimeIndex(['2023-01-01', '2023-01-01', '2023-02-01', '2023-02-01'], dtype='datetime64[ns]', freq=None) """ - @datetime_index_not_implemented() def tz_convert(self, tz) -> DatetimeIndex: """ Convert tz-aware Datetime Array/Index from one time zone to another. @@ -1025,8 +1024,14 @@ def tz_convert(self, tz) -> DatetimeIndex: '2014-08-01 09:00:00'], dtype='datetime64[ns]', freq='h') """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. + return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_convert( + tz, + include_index=True, + ) + ) - @datetime_index_not_implemented() def tz_localize( self, tz, @@ -1104,21 +1109,29 @@ def tz_localize( Localize DatetimeIndex in US/Eastern time zone: - >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP - >>> tz_aware # doctest: +SKIP - DatetimeIndex(['2018-03-01 09:00:00-05:00', - '2018-03-02 09:00:00-05:00', + >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') + >>> tz_aware + DatetimeIndex(['2018-03-01 09:00:00-05:00', '2018-03-02 09:00:00-05:00', '2018-03-03 09:00:00-05:00'], - dtype='datetime64[ns, US/Eastern]', freq=None) + dtype='datetime64[ns, UTC-05:00]', freq=None) With the ``tz=None``, we can remove the time zone information while keeping the local time (not converted to UTC): - >>> tz_aware.tz_localize(None) # doctest: +SKIP + >>> tz_aware.tz_localize(None) DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', '2018-03-03 09:00:00'], dtype='datetime64[ns]', freq=None) """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. 
+ return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_localize( + tz, + ambiguous, + nonexistent, + include_index=True, + ) + ) def round( self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 12710224de7..b25bb481dc0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -30,6 +30,7 @@ import modin import numpy as np import pandas as native_pd +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas import get_option from pandas._libs import lib @@ -49,7 +50,6 @@ ) from pandas.core.dtypes.inference import is_hashable -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin._internal.timestamp_utils import DateTimeOrigin @@ -71,6 +71,35 @@ } +class IndexParent: + def __init__(self, parent: DataFrame | Series) -> None: + """ + Initialize the IndexParent object. + + IndexParent is used to keep track of the parent object that the Index is a part of. + It tracks the parent object and the parent object's query compiler at the time of creation. + + Parameters + ---------- + parent : DataFrame or Series + The parent object that the Index is a part of. + """ + assert isinstance(parent, (DataFrame, Series)) + self._parent = parent + self._parent_qc = parent._query_compiler + + def check_and_update_parent_qc_index_names(self, names: list) -> None: + """ + Update the Index and its parent's index names if the query compiler associated with the parent is + different from the original query compiler recorded, i.e., an inplace update has been applied to the parent. + """ + if self._parent._query_compiler is self._parent_qc: + new_query_compiler = self._parent_qc.set_index_names(names) + self._parent._update_inplace(new_query_compiler=new_query_compiler) + # Update the query compiler after naming operation. + self._parent_qc = new_query_compiler + + class Index(metaclass=TelemetryMeta): # Equivalent index type in native pandas @@ -135,7 +164,7 @@ def __new__( index = object.__new__(cls) # Initialize the Index index._query_compiler = query_compiler - # `_parent` keeps track of any Series or DataFrame that this Index is a part of. + # `_parent` keeps track of the parent object that this Index is a part of. index._parent = None return index @@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any: ErrorMessage.not_implemented(f"Index.{key} is not yet implemented") raise err + def _set_parent(self, parent: Series | DataFrame) -> None: + """ + Set the parent object and its query compiler. + + Parameters + ---------- + parent : Series or DataFrame + The parent object that the Index is a part of. + """ + self._parent = IndexParent(parent) + def _binary_ops(self, method: str, other: Any) -> Index: if isinstance(other, Index): other = other.to_series().reset_index(drop=True) @@ -408,12 +448,6 @@ def __constructor__(self): """ return type(self) - def _set_parent(self, parent: Series | DataFrame): - """ - Set the parent object of the current Index to a given Series or DataFrame. 
- """ - self._parent = parent - @property def values(self) -> ArrayLike: """ @@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None: if not is_hashable(value): raise TypeError(f"{type(self).__name__}.name must be a hashable type") self._query_compiler = self._query_compiler.set_index_names([value]) + # Update the name of the parent's index only if an inplace update is performed on + # the parent object, i.e., the parent's current query compiler matches the originally + # recorded query compiler. if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names([value]) - ) + self._parent.check_and_update_parent_qc_index_names([value]) def _get_names(self) -> list[Hashable]: """ @@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None: if isinstance(values, Index): values = values.to_list() self._query_compiler = self._query_compiler.set_index_names(values) + # Update the name of the parent's index only if the parent's current query compiler + # matches the recorded query compiler. if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names(values) - ) + self._parent.check_and_update_parent_qc_index_names(values) names = property(fset=_set_names, fget=_get_names) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py index c5f9e4f6cee..43f9603cfb4 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py @@ -9,10 +9,10 @@ import inspect from typing import Any, Iterable, Literal, Optional, Union +from modin.pandas import DataFrame, Series from pandas._typing import IndexLabel from snowflake.snowpark import DataFrame as SnowparkDataFrame -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_standalone_function_decorator, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py index dea98bbb0d3..6d6fb4cd0bd 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py @@ -15,6 +15,7 @@ ) import pandas as native_pd +from modin.pandas import DataFrame from pandas._libs.lib import NoDefault, no_default from pandas._typing import ( CSVEngine, @@ -26,7 +27,6 @@ ) import snowflake.snowpark.modin.pandas as pd -from snowflake.snowpark.modin.pandas import DataFrame from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_standalone_function_decorator, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py index f7bba4c743a..5b245bfdab4 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py @@ -181,7 +181,6 @@ def to_pandas( See Also: - :func:`to_pandas ` - - :func:`DataFrame.to_pandas ` Returns: pandas Series diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py 
b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 5011defa685..b104c223e26 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -9,22 +9,13 @@ from __future__ import annotations -from typing import ( - IO, - TYPE_CHECKING, - Any, - Callable, - Hashable, - Literal, - Mapping, - Sequence, -) +from typing import IO, Any, Callable, Hashable, Literal, Mapping, Sequence import modin.pandas as pd import numpy as np import numpy.typing as npt import pandas as native_pd -from modin.pandas import Series +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas._libs.lib import NoDefault, is_integer, no_default from pandas._typing import ( @@ -73,9 +64,6 @@ validate_int_kwarg, ) -if TYPE_CHECKING: - from modin.pandas import DataFrame - def register_series_not_implemented(): def decorator(base_method: Any): @@ -209,21 +197,6 @@ def hist( pass # pragma: no cover -@register_series_not_implemented() -def interpolate( - self, - method="linear", - axis=0, - limit=None, - inplace=False, - limit_direction: str | None = None, - limit_area=None, - downcast=None, - **kwargs, -): # noqa: PR01, RT01, D200 - pass # pragma: no cover - - @register_series_not_implemented() def item(self): # noqa: RT01, D200 pass # pragma: no cover @@ -1451,9 +1424,7 @@ def set_axis( ) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. +# Snowpark pandas does different validation. @register_series_accessor("rename") def rename( self, @@ -1503,9 +1474,36 @@ def rename( return self_cp -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. +# Modin defaults to pandas for some arguments for unstack +@register_series_accessor("unstack") +def unstack( + self, + level: int | str | list = -1, + fill_value: int | str | dict = None, + sort: bool = True, +): + """ + Unstack, also known as pivot, Series with MultiIndex to produce DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + from modin.pandas.dataframe import DataFrame + + # We can't unstack a Series object, if we don't have a MultiIndex. + if self._query_compiler.has_multiindex: + result = DataFrame( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=True + ) + ) + else: + raise ValueError( # pragma: no cover + f"index must be a MultiIndex to unstack, {type(self.index)} was passed" + ) + + return result + + +# Snowpark pandas does an extra check on `len(ascending)`. @register_series_accessor("sort_values") def sort_values( self, @@ -1521,7 +1519,7 @@ def sort_values( Sort by the values. """ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame + from modin.pandas.dataframe import DataFrame if is_list_like(ascending) and len(ascending) != 1: raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series") @@ -1550,38 +1548,6 @@ def sort_values( return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. 
-# Modin also defaults to pandas for some arguments for unstack -@register_series_accessor("unstack") -def unstack( - self, - level: int | str | list = -1, - fill_value: int | str | dict = None, - sort: bool = True, -): - """ - Unstack, also known as pivot, Series with MultiIndex to produce DataFrame. - """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - # We can't unstack a Series object, if we don't have a MultiIndex. - if self._query_compiler.has_multiindex: - result = DataFrame( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=True - ) - ) - else: - raise ValueError( # pragma: no cover - f"index must be a MultiIndex to unstack, {type(self.index)} was passed" - ) - - return result - - # Upstream Modin defaults at the frontend layer. @register_series_accessor("where") def where( @@ -1727,63 +1693,6 @@ def to_dict(self, into: type[dict] = dict) -> dict: return self._to_pandas().to_dict(into=into) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored -# version of DataFrame, we must keep this override. -@register_series_accessor("_create_or_update_from_compiler") -def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): - """ - Return or update a Series with given `new_query_compiler`. - - Parameters - ---------- - new_query_compiler : PandasQueryCompiler - QueryCompiler to use to manage the data. - inplace : bool, default: False - Whether or not to perform update or creation inplace. - - Returns - ------- - Series, DataFrame or None - None if update was done, Series or DataFrame otherwise. - """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - assert ( - isinstance(new_query_compiler, type(self._query_compiler)) - or type(new_query_compiler) in self._query_compiler.__class__.__bases__ - ), f"Invalid Query Compiler object: {type(new_query_compiler)}" - if not inplace and new_query_compiler.is_series_like(): - return self.__constructor__(query_compiler=new_query_compiler) - elif not inplace: - # This can happen with things like `reset_index` where we can add columns. - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - return DataFrame(query_compiler=new_query_compiler) - else: - self._update_inplace(new_query_compiler=new_query_compiler) - - -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored -# version of DataFrame, we must keep this override. -@register_series_accessor("to_frame") -def to_frame(self, name: Hashable = no_default) -> DataFrame: # noqa: PR01, RT01, D200 - """ - Convert Series to {label -> value} dict or dict-like object. 
- """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - if name is None: - name = no_default - - self_cp = self.copy() - if name is not no_default: - self_cp.name = name - - return DataFrame(self_cp) - - @register_series_accessor("to_numpy") def to_numpy( self, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 96e2913f556..1cd5e31c63f 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -28,11 +28,11 @@ import numpy as np import pandas as native_pd +from modin.pandas import DataFrame, Series from pandas._libs import lib from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable from pandas.core.dtypes.common import is_timedelta64_dtype -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, ) @@ -392,12 +392,11 @@ def to_pytimedelta(self) -> np.ndarray: datetime.timedelta(days=3)], dtype=object) """ - @timedelta_index_not_implemented() def mean( self, *, skipna: bool = True, axis: AxisInt | None = 0 - ) -> native_pd.Timestamp: + ) -> native_pd.Timedelta: """ - Return the mean value of the Array. + Return the mean value of the Timedelta values. Parameters ---------- @@ -407,17 +406,46 @@ def mean( Returns ------- - scalar Timestamp + scalar Timedelta + + Examples + -------- + >>> idx = pd.to_timedelta([1, 2, 3, 1], unit='D') + >>> idx + TimedeltaIndex(['1 days', '2 days', '3 days', '1 days'], dtype='timedelta64[ns]', freq=None) + >>> idx.mean() + Timedelta('1 days 18:00:00') See Also -------- numpy.ndarray.mean : Returns the average of array elements along a given axis. Series.mean : Return the mean value in a Series. - - Notes - ----- - mean is only defined for Datetime and Timedelta dtypes, not for Period. """ + if axis: + # Native pandas raises IndexError: tuple index out of range + # We raise a different more user-friendly error message. + raise ValueError( + f"axis should be 0 for TimedeltaIndex.mean, found '{axis}'" + ) + pandas_dataframe_result = ( + # reset_index(drop=False) copies the index column of + # self._query_compiler into a new data column. Use `drop=False` + # so that we don't have to use SQL row_number() to generate a new + # index column. + self._query_compiler.reset_index(drop=False) + # Aggregate the data column. + .agg("mean", axis=0, args=(), kwargs={"skipna": skipna}) + # convert the query compiler to a pandas dataframe with + # dimensions 1x1 (note that the frame has a single row even + # if `self` is empty.) + .to_pandas() + ) + assert pandas_dataframe_result.shape == ( + 1, + 1, + ), "Internal error: aggregation result is not 1x1." + # Return the only element in the frame. 
+ return pandas_dataframe_result.iloc[0, 0] @timedelta_index_not_implemented() def as_unit(self, unit: str) -> TimedeltaIndex: diff --git a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py index 785a492ca89..f3102115a32 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py +++ b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py @@ -42,3 +42,17 @@ SERIES_SETITEM_INCOMPATIBLE_INDEXER_WITH_SCALAR_ERROR_MESSAGE = ( "Scalar key incompatible with {} value" ) + +DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = ( + "Currently do not support Series or list-like keys with range-like values" +) + +DF_SETITEM_SLICE_AS_SCALAR_VALUE = ( + "Currently do not support assigning a slice value as if it's a scalar value" +) + +DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = ( + "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark " + "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which " + "can work on the entire DataFrame in one shot." +) diff --git a/src/snowflake/snowpark/modin/utils.py b/src/snowflake/snowpark/modin/utils.py index b1027f00e33..b3446ca0362 100644 --- a/src/snowflake/snowpark/modin/utils.py +++ b/src/snowflake/snowpark/modin/utils.py @@ -1171,7 +1171,7 @@ def validate_int_kwarg(value: int, arg_name: str, float_allowed: bool = False) - def doc_replace_dataframe_with_link(_obj: Any, doc: str) -> str: """ Helper function to be passed as the `modify_doc` parameter to `_inherit_docstrings`. This replaces - all unqualified instances of "DataFrame" with ":class:`~snowflake.snowpark.pandas.DataFrame`" to + all unqualified instances of "DataFrame" with ":class:`~modin.pandas.DataFrame`" to prevent it from linking automatically to snowflake.snowpark.DataFrame: see SNOW-1233342. To prevent it from overzealously replacing examples in doctests or already-qualified paths, it diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index c6faa5c9b3b..a586cb7c000 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -221,6 +221,16 @@ _PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION = ( "PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION" ) +_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND = ( + "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND" +) +_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND = ( + "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND" +) +# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT +# in Snowflake. This is the limit where we start seeing compilation errors. +DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000 +DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000 WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None @@ -575,14 +585,22 @@ def __init__( _PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION, False ) ) + # The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT + # in Snowflake. This is the limit where we start seeing compilation errors. 
+ self._large_query_breakdown_complexity_bounds: Tuple[int, int] = ( + self._conn._get_client_side_session_parameter( + _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + ), + self._conn._get_client_side_session_parameter( + _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + ), + ) self._custom_package_usage_config: Dict = {} self._conf = self.RuntimeConfig(self, options or {}) - self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) - if self._auto_clean_up_temp_table_enabled: - self._temp_table_auto_cleaner.start() - _logger.info("Snowpark Session information: %s", self._session_info) def __enter__(self): @@ -621,8 +639,8 @@ def close(self) -> None: raise SnowparkClientExceptionMessages.SERVER_FAILED_CLOSE_SESSION(str(ex)) finally: try: - self._conn.close() self._temp_table_auto_cleaner.stop() + self._conn.close() _logger.info("Closed session: %s", self._session_id) finally: _remove_session(self) @@ -656,10 +674,33 @@ def auto_clean_up_temp_table_enabled(self) -> bool: :meth:`DataFrame.cache_result` in the current session when the DataFrame is no longer referenced (i.e., gets garbage collected). The default value is ``False``. + Example:: + + >>> import gc + >>> + >>> def f(session: Session) -> str: + ... df = session.create_dataframe( + ... [[1, 2], [3, 4]], schema=["a", "b"] + ... ).cache_result() + ... return df.table_name + ... + >>> session.auto_clean_up_temp_table_enabled = True + >>> table_name = f(session) + >>> assert table_name + >>> gc.collect() # doctest: +SKIP + >>> + >>> # The temporary table created by cache_result will be dropped when the DataFrame is no longer referenced + >>> # outside the function + >>> session.sql(f"show tables like '{table_name}'").count() + 0 + + >>> session.auto_clean_up_temp_table_enabled = False + Note: - Even if this parameter is ``False``, Snowpark still records temporary tables when - their corresponding DataFrame are garbage collected. Therefore, if you turn it on in the middle of your session or after turning it off, - the target temporary tables will still be cleaned up accordingly. + Temporary tables will only be dropped if this parameter is enabled during garbage collection. + If a temporary table is no longer referenced when the parameter is on, it will be dropped during garbage collection. + However, if garbage collection occurs while the parameter is off, the table will not be removed. + Note that Python's garbage collection is triggered opportunistically, with no guaranteed timing. """ return self._auto_clean_up_temp_table_enabled @@ -667,6 +708,10 @@ def auto_clean_up_temp_table_enabled(self) -> bool: def large_query_breakdown_enabled(self) -> bool: return self._large_query_breakdown_enabled + @property + def large_query_breakdown_complexity_bounds(self) -> Tuple[int, int]: + return self._large_query_breakdown_complexity_bounds + @property def custom_package_usage_config(self) -> Dict: """Get or set configuration parameters related to usage of custom Python packages in Snowflake. 
@@ -753,11 +798,6 @@ def auto_clean_up_temp_table_enabled(self, value: bool) -> None: self._session_id, value ) self._auto_clean_up_temp_table_enabled = value - is_alive = self._temp_table_auto_cleaner.is_alive() - if value and not is_alive: - self._temp_table_auto_cleaner.start() - elif not value and is_alive: - self._temp_table_auto_cleaner.stop() else: raise ValueError( "value for auto_clean_up_temp_table_enabled must be True or False!" @@ -782,6 +822,24 @@ def large_query_breakdown_enabled(self, value: bool) -> None: "value for large_query_breakdown_enabled must be True or False!" ) + @large_query_breakdown_complexity_bounds.setter + def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: + """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" + + if len(value) != 2: + raise ValueError( + f"Expecting a tuple of two integers. Got a tuple of length {len(value)}" + ) + if value[0] >= value[1]: + raise ValueError( + f"Expecting a tuple of lower and upper bound with the lower bound less than the upper bound. Got (lower, upper) = ({value[0], value[1]})" + ) + self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds( + self._session_id, value[0], value[1] + ) + + self._large_query_breakdown_complexity_bounds = value + @custom_package_usage_config.setter @experimental_parameter(version="1.6.0") def custom_package_usage_config(self, config: Dict) -> None: @@ -1649,8 +1707,8 @@ def _upload_unsupported_packages( try: # Setup a temporary directory and target folder where pip install will take place. - self._tmpdir_handler = tempfile.TemporaryDirectory() - tmpdir = self._tmpdir_handler.name + tmpdir_handler = tempfile.TemporaryDirectory() + tmpdir = tmpdir_handler.name target = os.path.join(tmpdir, "unsupported_packages") if not os.path.exists(target): os.makedirs(target) @@ -1735,9 +1793,7 @@ def _upload_unsupported_packages( for requirement in supported_dependencies + new_dependencies ] ) - metadata_local_path = os.path.join( - self._tmpdir_handler.name, metadata_file - ) + metadata_local_path = os.path.join(tmpdir_handler.name, metadata_file) with open(metadata_local_path, "w") as file: for key, value in metadata.items(): file.write(f"{key},{value}\n") @@ -1773,9 +1829,8 @@ def _upload_unsupported_packages( f"-third-party-packages-from-anaconda-in-a-udf." 
) finally: - if self._tmpdir_handler: - self._tmpdir_handler.cleanup() - self._tmpdir_handler = None + if tmpdir_handler: + tmpdir_handler.cleanup() return supported_dependencies + new_dependencies @@ -3094,7 +3149,9 @@ def _use_object(self, object_name: str, object_type: str) -> None: # we do not validate here object_type = match.group(1) object_name = match.group(2) - setattr(self._conn, f"_active_{object_type}", object_name) + mock_conn_lock = self._conn.get_lock() + with mock_conn_lock: + setattr(self._conn, f"_active_{object_type}", object_name) else: self._run_query(query) else: diff --git a/src/snowflake/snowpark/version.py b/src/snowflake/snowpark/version.py index 3955dbbbf33..798a3d902d0 100644 --- a/src/snowflake/snowpark/version.py +++ b/src/snowflake/snowpark/version.py @@ -4,4 +4,4 @@ # # Update this for the versions -VERSION = (1, 21, 1) +VERSION = (1, 22, 1) diff --git a/tests/integ/modin/conftest.py b/tests/integ/modin/conftest.py index 2f24954e769..a7217b38a50 100644 --- a/tests/integ/modin/conftest.py +++ b/tests/integ/modin/conftest.py @@ -715,3 +715,30 @@ def numeric_test_data_4x4(): "C": [7, 10, 13, 16], "D": [8, 11, 14, 17], } + + +@pytest.fixture +def timedelta_native_df() -> pandas.DataFrame: + return pandas.DataFrame( + { + "A": [ + pd.Timedelta(days=1), + pd.Timedelta(days=2), + pd.Timedelta(days=3), + pd.Timedelta(days=4), + ], + "B": [ + pd.Timedelta(minutes=-1), + pd.Timedelta(minutes=0), + pd.Timedelta(minutes=5), + pd.Timedelta(minutes=6), + ], + "C": [ + None, + pd.Timedelta(nanoseconds=5), + pd.Timedelta(nanoseconds=0), + pd.Timedelta(nanoseconds=4), + ], + "D": pandas.to_timedelta([pd.NaT] * 4), + } + ) diff --git a/tests/integ/modin/frame/test_aggregate.py b/tests/integ/modin/frame/test_aggregate.py index b018682b6f8..ba68ae13734 100644 --- a/tests/integ/modin/frame/test_aggregate.py +++ b/tests/integ/modin/frame/test_aggregate.py @@ -187,6 +187,108 @@ def test_string_sum_with_nulls(): assert_series_equal(snow_result.to_pandas(), native_pd.Series(["ab"])) +class TestTimedelta: + """Test aggregating dataframes containing timedelta columns.""" + + @pytest.mark.parametrize( + "func, union_count", + [ + param( + lambda df: df.aggregate(["min"]), + 0, + id="aggregate_list_with_one_element", + ), + param(lambda df: df.aggregate(x=("A", "max")), 0, id="single_named_agg"), + # this works since all results are timedelta and we don't need to do any concats. + param( + lambda df: df.aggregate({"B": "mean", "A": "sum"}), + 0, + id="dict_producing_two_timedeltas", + ), + # this works since even though we need to do concats, all the results are non-timdelta. 
+ param( + lambda df: df.aggregate(x=("B", "all"), y=("B", "any")), + 1, + id="named_agg_producing_two_bools", + ), + # note following aggregation requires transpose + param(lambda df: df.aggregate(max), 0, id="aggregate_max"), + param(lambda df: df.min(), 0, id="min"), + param(lambda df: df.max(), 0, id="max"), + param(lambda df: df.count(), 0, id="count"), + param(lambda df: df.sum(), 0, id="sum"), + param(lambda df: df.mean(), 0, id="mean"), + param(lambda df: df.median(), 0, id="median"), + param(lambda df: df.std(), 0, id="std"), + param(lambda df: df.quantile(), 0, id="single_quantile"), + param(lambda df: df.quantile([0.01, 0.99]), 1, id="two_quantiles"), + ], + ) + def test_supported_axis_0(self, func, union_count, timedelta_native_df): + with SqlCounter(query_count=1, union_count=union_count): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + func, + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1653126") + def test_axis_1(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), lambda df: df.sum(axis=1) + ) + + @sql_count_checker(query_count=0) + def test_var_invalid(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + lambda df: df.var(), + expect_exception=True, + expect_exception_type=TypeError, + assert_exception_equal=False, + expect_exception_match=re.escape( + "timedelta64 type does not support var operations" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", + ) + @pytest.mark.parametrize( + "operation", + [ + lambda df: df.aggregate({"A": ["count", "max"], "B": [max, "min"]}), + lambda df: df.aggregate({"B": ["count"], "A": "sum", "C": ["max", "min"]}), + lambda df: df.aggregate( + x=pd.NamedAgg("A", "max"), y=("B", "min"), c=("A", "count") + ), + lambda df: df.aggregate(["min", np.max]), + lambda df: df.aggregate(x=("A", "max"), y=("C", "min"), z=("A", "min")), + lambda df: df.aggregate(x=("A", "max"), y=pd.NamedAgg("A", "max")), + lambda df: df.aggregate( + {"B": ["idxmax"], "A": "sum", "C": ["max", "idxmin"]} + ), + ], + ) + def test_agg_requires_concat_with_timedelta(self, timedelta_native_df, operation): + eval_snowpark_pandas_result(*create_test_dfs(timedelta_native_df), operation) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires transposing a one-row frame with integer and timedelta.", + ) + def test_agg_produces_timedelta_and_non_timedelta_type(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + lambda df: df.aggregate({"B": "idxmax", "A": "sum"}), + ) + + @pytest.mark.parametrize( "func, expected_union_count", [ diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py index 1014cae44c9..ded0651046c 100644 --- a/tests/integ/modin/frame/test_apply.py +++ b/tests/integ/modin/frame/test_apply.py @@ -337,16 +337,6 @@ def f(x, y, z=1) -> int: class TestNotImplemented: - @pytest.mark.parametrize( - "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP - ) - @sql_count_checker(query_count=0) - def test_axis_0(self, data, func, return_type): - snow_df = pd.DataFrame(data) - msg = "Snowpark pandas apply API doesn't yet support axis == 0" - with pytest.raises(NotImplementedError, match=msg): - 
snow_df.apply(func) - @pytest.mark.parametrize("result_type", ["reduce", "expand", "broadcast"]) @sql_count_checker(query_count=0) def test_result_type(self, result_type): @@ -554,33 +544,70 @@ def g(v): ] -TRANSFORM_DATA_FUNC_MAP = [ - [[[0, 1, 2], [1, 2, 3]], lambda x: x + 1], - [[[0, 1, 2], [1, 2, 3]], np.exp], - [[[0, 1, 2], [1, 2, 3]], "exp"], - [[["Leonhard", "Jianzhun"]], lambda x: x + " is awesome!!"], - [[[1.3, 2.5]], np.sqrt], - [[[1.3, 2.5]], "sqrt"], - [[[1.3, 2.5]], np.log], - [[[1.3, 2.5]], "log"], - [[[1.3, 2.5]], np.square], - [[[1.3, 2.5]], "square"], +@pytest.mark.xfail( + strict=True, + raises=SnowparkSQLException, + reason="SNOW-1650918: Apply on dataframe data columns containing NULL fails with invalid arguments to udtf function", +) +@pytest.mark.parametrize( + "data, apply_func", [ - [[None, "abcd"]], - lambda x: x + " are first 4 letters of alphabet" if x is not None else None, + [ + [[None, "abcd"]], + lambda x: x + " are first 4 letters of alphabet" if x is not None else None, + ], + [ + [[123, None]], + lambda x: x + 100 if x is not None else None, + ], ], - [[[1.5, float("nan")]], lambda x: np.sqrt(x)], +) +def test_apply_bug_1650918(data, apply_func): + native_df = native_pd.DataFrame(data) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(apply_func, axis=1), + ) + + +TRANSFORM_TEST_MAP = [ + [[[0, 1, 2], [1, 2, 3]], lambda x: x + 1, 16], + [[[0, 1, 2], [1, 2, 3]], np.exp, 16], + [[[0, 1, 2], [1, 2, 3]], "exp", None], + [[["Leonhard", "Jianzhun"]], lambda x: x + " is awesome!!", 11], + [[[1.3, 2.5]], np.sqrt, 11], + [[[1.3, 2.5]], "sqrt", None], + [[[1.3, 2.5]], np.log, 11], + [[[1.3, 2.5]], "log", None], + [[[1.3, 2.5]], np.square, 11], + [[[1.3, 2.5]], "square", None], + [[[1.5, float("nan")]], lambda x: np.sqrt(x), 11], ] @pytest.mark.modin_sp_precommit -@pytest.mark.parametrize("data, apply_func", TRANSFORM_DATA_FUNC_MAP) -@sql_count_checker(query_count=0) -def test_basic_dataframe_transform(data, apply_func): - msg = "Snowpark pandas apply API doesn't yet support axis == 0" - with pytest.raises(NotImplementedError, match=msg): +@pytest.mark.parametrize("data, apply_func, expected_query_count", TRANSFORM_TEST_MAP) +def test_basic_dataframe_transform(data, apply_func, expected_query_count): + if expected_query_count is None: + msg = "Snowpark pandas apply API only supports callables func" + with SqlCounter(query_count=0): + with pytest.raises(NotImplementedError, match=msg): + snow_df = pd.DataFrame(data) + snow_df.transform(apply_func) + else: + msg = "SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function" + native_df = native_pd.DataFrame(data) snow_df = pd.DataFrame(data) - snow_df.transform(apply_func) + with SqlCounter( + query_count=expected_query_count, + high_count_expected=True, + high_count_reason=msg, + ): + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.transform(apply_func) + ) AGGREGATION_FUNCTIONS = [ @@ -610,7 +637,7 @@ def test_dataframe_transform_invalid_function_name_negative(session): snow_df = pd.DataFrame([[0, 1, 2], [1, 2, 3]]) with pytest.raises( NotImplementedError, - match="Snowpark pandas apply API doesn't yet support axis == 0", + match="Snowpark pandas apply API only supports callables func", ): snow_df.transform("mxyzptlk") diff --git a/tests/integ/modin/frame/test_apply_axis_0.py b/tests/integ/modin/frame/test_apply_axis_0.py new file mode 100644 index 00000000000..47fd14d7b98 --- /dev/null +++ 
b/tests/integ/modin/frame/test_apply_axis_0.py @@ -0,0 +1,653 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import datetime + +import modin.pandas as pd +import numpy as np +import pandas as native_pd +import pytest + +import snowflake.snowpark.modin.plugin # noqa: F401 +from snowflake.snowpark.exceptions import SnowparkSQLException +from tests.integ.modin.series.test_apply import create_func_with_return_type_hint +from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker +from tests.integ.modin.utils import ( + assert_snowpark_pandas_equal_to_pandas, + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + create_test_dfs, + eval_snowpark_pandas_result, +) + +# test data which has a python type as return type that is not a pandas Series/pandas DataFrame/tuple/list +BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP = [ + [[[1.0, 2.2], [3, np.nan]], np.min, "float"], + [[[1.1, 2.2], [3, np.nan]], lambda x: x.sum(), "float"], + [[[1.1, 2.2], [3, np.nan]], lambda x: x.size, "int"], + [[[1.1, 2.2], [3, np.nan]], lambda x: "0" if x.sum() > 1 else 0, "object"], + [[["snow", "flake"], ["data", "cloud"]], lambda x: x[0] + x[1], "str"], + [[[True, False], [False, False]], lambda x: True, "bool"], + [[[True, False], [False, False]], lambda x: x[0] ^ x[1], "bool"], + ( + [ + [bytes("snow", "utf-8"), bytes("flake", "utf-8")], + [bytes("data", "utf-8"), bytes("cloud", "utf-8")], + ], + lambda x: (x[0] + x[1]).decode(), + "str", + ), + ( + [[["a", "b"], ["c", "d"]], [["a", "b"], ["c", "d"]]], + lambda x: x[0][1] + x[1][0], + "str", + ), + ( + [[{"a": "b"}, {"c": "d"}], [{"c": "b"}, {"a": "d"}]], + lambda x: str(x[0]) + str(x[1]), + "str", + ), +] + + +@pytest.mark.parametrize( + "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP +) +@pytest.mark.modin_sp_precommit +def test_axis_0_basic_types_without_type_hints(data, func, return_type): + # this test processes functions without type hints and invokes the UDTF solution. + native_df = native_pd.DataFrame(data, columns=["A", "b"]) + snow_df = pd.DataFrame(data, columns=["A", "b"]) + with SqlCounter( + query_count=11, + join_count=2, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result(snow_df, native_df, lambda x: x.apply(func, axis=0)) + + +@pytest.mark.parametrize( + "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP +) +@pytest.mark.modin_sp_precommit +def test_axis_0_basic_types_with_type_hints(data, func, return_type): + # create explicitly for supported python types UDF with type hints and process via vUDF. + native_df = native_pd.DataFrame(data, columns=["A", "b"]) + snow_df = pd.DataFrame(data, columns=["A", "b"]) + func_with_type_hint = create_func_with_return_type_hint(func, return_type) + # Invoking a single UDF typically requires 3 queries (package management, code upload, UDF registration) upfront. 
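+    # The higher expected counts below reflect the current axis=0 implementation: udtf_count matches the
+    # number of data columns (two here), and the remaining queries come from the temp-object caching and
+    # repeated temp-function creation tracked by SNOW-1650644 / SNOW-1345395.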
+ with SqlCounter( + query_count=11, + join_count=2, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(func_with_type_hint, axis=0) + ) + + +@pytest.mark.parametrize( + "df,row_label", + [ + ( + native_pd.DataFrame( + [[1, 2], [None, 3]], columns=["A", "b"], index=["A", "B"] + ), + "B", + ), + ( + native_pd.DataFrame( + [[1, 2], [None, 3]], + columns=["A", "b"], + index=pd.MultiIndex.from_tuples([(1, 2), (1, 1)]), + ), + (1, 2), + ), + ], +) +def test_axis_0_index_passed_as_name(df, row_label): + # when using apply(axis=1) the original index of the dataframe is passed as name. + # test here for this for regular index and multi-index scenario. + + def foo(row) -> str: + if row.name == row_label: + return "MATCHING LABEL" + else: + return "NO MATCH" + + snow_df = pd.DataFrame(df) + with SqlCounter( + query_count=11, + join_count=2, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=0)) + + +@sql_count_checker( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_return_series(): + snow_df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"]) + native_df = native_pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"]) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(lambda x: native_pd.Series([1, 2], index=["C", "d"]), axis=0), + ) + + +@sql_count_checker( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_return_series_with_different_label_results(): + df = native_pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"]) + snow_df = pd.DataFrame(df) + + eval_snowpark_pandas_result( + snow_df, + df, + lambda df: df.apply( + lambda x: native_pd.Series([1, 2], index=["a", "b"]) + if x.sum() > 3 + else native_pd.Series([0, 1, 2], index=["c", "a", "b"]), + axis=0, + ), + ) + + +@sql_count_checker(query_count=6, join_count=1, udtf_count=1) +def test_axis_0_return_single_scalar_series(): + native_df = native_pd.DataFrame([1]) + snow_df = pd.DataFrame(native_df) + + def apply_func(x): + return native_pd.Series([1], index=["xyz"]) + + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(apply_func, axis=0) + ) + + +@sql_count_checker(query_count=3) +def test_axis_0_return_dataframe_not_supported(): + snow_df = pd.DataFrame([1]) + + # Note that pands returns failure "ValueError: If using all scalar values, you must pass an index" which + # doesn't explain this isn't supported. We go with the default returned by pandas in this case. + with pytest.raises( + SnowparkSQLException, match="The truth value of a DataFrame is ambiguous." 
+ ): + # return value + snow_df.apply(lambda x: native_pd.DataFrame([1, 2]), axis=0).to_pandas() + + +class TestNotImplemented: + @pytest.mark.parametrize("result_type", ["reduce", "expand", "broadcast"]) + @sql_count_checker(query_count=0) + def test_result_type(self, result_type): + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + msg = "Snowpark pandas apply API doesn't yet support 'result_type' parameter" + with pytest.raises(NotImplementedError, match=msg): + snow_df.apply(lambda x: [1, 2], axis=0, result_type=result_type) + + @sql_count_checker(query_count=0) + def test_axis_1_apply_args_kwargs_with_snowpandas_object(self): + def f(x, y=None) -> native_pd.Series: + return x + (y if y is not None else 0) + + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + msg = "Snowpark pandas apply API doesn't yet support DataFrame or Series in 'args' or 'kwargs' of 'func'" + with pytest.raises(NotImplementedError, match=msg): + snow_df.apply(f, axis=0, args=(pd.Series([1, 2]),)) + with pytest.raises(NotImplementedError, match=msg): + snow_df.apply(f, axis=0, y=pd.Series([1, 2])) + + +TEST_INDEX_1 = native_pd.MultiIndex.from_tuples( + list(zip(*[["a", "b"], ["x", "y"]])), + names=["first", "last"], +) + + +TEST_INDEX_WITH_NULL_1 = native_pd.MultiIndex.from_tuples( + list(zip(*[[None, "b"], ["x", None]])), + names=["first", "last"], +) + + +TEST_INDEX_2 = native_pd.MultiIndex.from_tuples( + list(zip(*[["AA", "BB"], ["XX", "YY"]])), + names=["FOO", "BAR"], +) + +TEST_INDEX_WITH_NULL_2 = native_pd.MultiIndex.from_tuples( + list(zip(*[[None, "BB"], ["XX", None]])), + names=["FOO", "BAR"], +) + + +TEST_COLUMNS_1 = native_pd.MultiIndex.from_tuples( + list( + zip( + *[ + ["car", "motorcycle", "bike", "bus"], + ["blue", "green", "red", "yellow"], + ] + ) + ), + names=["vehicle", "color"], +) + + +@pytest.mark.parametrize( + "apply_func, expected_join_count, expected_union_count", + [ + [lambda x: [1, 2], 3, 0], + [lambda x: x + 1 if x is not None else None, 3, 0], + [lambda x: x.min(), 2, 1], + ], +) +def test_axis_0_series_basic(apply_func, expected_join_count, expected_union_count): + native_df = native_pd.DataFrame( + [[1.1, 2.2], [3.0, None]], index=pd.Index([2, 3]), columns=["A", "b"] + ) + snow_df = pd.DataFrame(native_df) + with SqlCounter( + query_count=11, + join_count=expected_join_count, + udtf_count=2, + union_count=expected_union_count, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.apply(apply_func, axis=0), + ) + + +@sql_count_checker(query_count=5, join_count=1, udtf_count=1) +def test_groupby_apply_constant_output(): + native_df = native_pd.DataFrame([1, 2]) + native_df["fg"] = 0 + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by=["fg"], axis=0).apply(lambda x: [1, 2]), + ) + + +@sql_count_checker( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_return_list(): + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + native_df = native_pd.DataFrame([[1, 2], [3, 4]]) + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(lambda x: [1, 2], axis=0) + ) + + +@pytest.mark.parametrize( + "apply_func", + [ + lambda x: -x, + lambda x: native_pd.Series([1, 2], index=TEST_INDEX_1), + lambda x: 
native_pd.Series([3, 4], index=TEST_INDEX_2), + lambda x: native_pd.Series([1, 2], index=TEST_INDEX_WITH_NULL_1), + lambda x: native_pd.Series([1, 2], index=TEST_INDEX_WITH_NULL_1), + ], +) +@sql_count_checker( + query_count=21, + join_count=7, + udtf_count=4, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_multi_index_column_labels(apply_func): + data = [[i + j for j in range(0, 4)] for i in range(0, 4)] + + native_df = native_pd.DataFrame(data, columns=TEST_COLUMNS_1) + snow_df = pd.DataFrame(data, columns=TEST_COLUMNS_1) + + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(apply_func, axis=0) + ) + + +@sql_count_checker( + query_count=21, + join_count=7, + udtf_count=4, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_multi_index_column_labels_with_different_results(): + data = [[i + j for j in range(0, 4)] for i in range(0, 4)] + + df = native_pd.DataFrame(data, columns=TEST_COLUMNS_1) + snow_df = pd.DataFrame(df) + + apply_func = ( + lambda x: native_pd.Series([1, 2], index=TEST_INDEX_1) + if min(x) == 0 + else native_pd.Series([3, 4], index=TEST_INDEX_2) + ) + + eval_snowpark_pandas_result(snow_df, df, lambda df: df.apply(apply_func, axis=0)) + + +@pytest.mark.parametrize( + "data, func, expected_result", + [ + [ + [ + [datetime.date(2023, 1, 1), None], + [datetime.date(2022, 12, 31), datetime.date(2021, 1, 9)], + ], + lambda x: x.dt.day, + native_pd.DataFrame([[1, np.nan], [31, 9.0]]), + ], + [ + [ + [datetime.time(1, 2, 3), None], + [datetime.time(1, 2, 3, 1), datetime.time(1)], + ], + lambda x: x.dt.seconds, + native_pd.DataFrame([[3723, np.nan], [3723, 3600]]), + ], + [ + [ + [datetime.datetime(2023, 1, 1, 1, 2, 3), None], + [ + datetime.datetime(2022, 12, 31, 1, 2, 3, 1), + datetime.datetime( + 2023, 1, 1, 1, 2, 3, tzinfo=datetime.timezone.utc + ), + ], + ], + lambda x: x.astype(str), + native_pd.DataFrame( + [ + ["2023-01-01 01:02:03.000000", "NaT"], + ["2022-12-31 01:02:03.000001", "2023-01-01 01:02:03+00:00"], + ] + ), + ], + ], +) +@sql_count_checker( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_date_time_timestamp_type(data, func, expected_result): + snow_df = pd.DataFrame(data) + result = snow_df.apply(func, axis=0) + + assert_snowpark_pandas_equal_to_pandas(result, expected_result) + + +@pytest.mark.parametrize( + "native_df, func", + [ + ( + native_pd.DataFrame([[1, 2], [3, 4]], index=["a", "b"]), + lambda x: x["a"] + x["b"], + ), + ( + native_pd.DataFrame( + [[1, 5], [2, 6], [3, 7], [4, 8]], + index=native_pd.MultiIndex.from_tuples( + [("baz", "A"), ("baz", "B"), ("zoo", "A"), ("zoo", "B")] + ), + ), + lambda x: x["baz", "B"] * x["zoo", "A"], + ), + ], +) +@sql_count_checker( + query_count=11, + join_count=2, + udtf_count=2, + union_count=1, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_index_labels(native_df, func): + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result(snow_df, native_df, lambda x: x.apply(func, axis=0)) + + +@sql_count_checker( + query_count=11, + join_count=2, + udtf_count=2, + union_count=1, + 
high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_raw(): + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + native_df = native_pd.DataFrame([[1, 2], [3, 4]]) + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(lambda x: str(type(x)), axis=0, raw=True) + ) + + +def test_axis_0_apply_args_kwargs(): + def f(x, y, z=1) -> int: + return x.sum() + y + z + + native_df = native_pd.DataFrame([[1, 2], [3, 4]]) + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + + with SqlCounter(query_count=3): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(f, axis=0), + expect_exception=True, + expect_exception_type=SnowparkSQLException, + expect_exception_match="missing 1 required positional argument", + assert_exception_equal=False, + ) + + with SqlCounter( + query_count=11, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(f, axis=0, args=(1,)) + ) + + with SqlCounter( + query_count=11, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(f, axis=0, args=(1,), z=2) + ) + + with SqlCounter(query_count=3): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(f, axis=0, args=(1,), z=2, v=3), + expect_exception=True, + expect_exception_type=SnowparkSQLException, + expect_exception_match="got an unexpected keyword argument", + assert_exception_equal=False, + ) + + +@pytest.mark.parametrize("data", [{"a": [1], "b": [2]}, {"a": [2], "b": [3]}]) +def test_apply_axis_0_with_if_where_duplicates_not_executed(data): + df = native_pd.DataFrame(data) + snow_df = pd.DataFrame(df) + + def foo(x): + return native_pd.Series( + [1, 2, 3], index=["C", "A", "E"] if x.sum() > 3 else ["A", "E", "E"] + ) + + with SqlCounter( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=0)) + + +@pytest.mark.parametrize( + "return_value", + [ + native_pd.Series(["a", np.int64(3)]), + ["a", np.int64(3)], + np.int64(3), + ], +) +@sql_count_checker(query_count=6, join_count=1, udtf_count=1) +def test_numpy_integers_in_return_values_snow_1227264(return_value): + eval_snowpark_pandas_result( + *create_test_dfs(["a"]), lambda df: df.apply(lambda row: return_value, axis=0) + ) + + +@pytest.mark.xfail( + strict=True, + raises=SnowparkSQLException, + reason="SNOW-1650918: Apply on dataframe data columns containing NULL fails with invalid arguments to udtf function", +) +@pytest.mark.parametrize( + "data, apply_func", + [ + [ + [[None, "abcd"]], + lambda x: x + " are first 4 letters of alphabet" if x is not None else None, + ], + [ + [[123, None]], + lambda x: x + 100 if x is not None else None, + ], + ], +) +def test_apply_axis_0_bug_1650918(data, apply_func): + native_df = native_pd.DataFrame(data) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(apply_func, axis=0), + ) + + +def test_apply_nested_series_negative(): + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + + with SqlCounter( + 
query_count=10, + join_count=2, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + with pytest.raises( + NotImplementedError, + match=r"Nested pd.Series in result is not supported in DataFrame.apply\(axis=0\)", + ): + snow_df.apply( + lambda ser: 99 if ser.sum() == 4 else native_pd.Series([1, 2]), axis=0 + ).to_pandas() + + snow_df2 = pd.DataFrame([[1, 2, 3]]) + + with SqlCounter( + query_count=15, + join_count=3, + udtf_count=3, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + with pytest.raises( + NotImplementedError, + match=r"Nested pd.Series in result is not supported in DataFrame.apply\(axis=0\)", + ): + snow_df2.apply( + lambda ser: 99 + if ser.sum() == 2 + else native_pd.Series([100], index=["a"]), + axis=0, + ).to_pandas() + + +import scipy.stats # noqa: E402 + + +@pytest.mark.parametrize( + "packages,expected_query_count", + [ + (["scipy", "numpy"], 26), + (["scipy>1.1", "numpy<2.0"], 26), + # TODO: SNOW-1478188 Re-enable quarantined tests for 8.23 + # [scipy, np], 9), + ], +) +def test_apply_axis0_with_3rd_party_libraries_and_decorator( + packages, expected_query_count +): + data = [[1, 2, 3, 4, 5], [7, -20, 4.0, 7.0, None]] + + with SqlCounter( + query_count=expected_query_count, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + try: + pd.session.custom_package_usage_config["enabled"] = True + pd.session.add_packages(packages) + + df = pd.DataFrame(data) + + def func(row): + return np.dot(row, scipy.stats.norm.pdf(row)) + + snow_ans = df.apply(func, axis=0) + finally: + pd.session.clear_packages() + pd.session.clear_imports() + + # same in native pandas: + native_df = native_pd.DataFrame(data) + native_ans = native_df.apply(func, axis=0) + + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(snow_ans, native_ans) diff --git a/tests/integ/modin/frame/test_describe.py b/tests/integ/modin/frame/test_describe.py index a9668c5794f..4f1882d441d 100644 --- a/tests/integ/modin/frame/test_describe.py +++ b/tests/integ/modin/frame/test_describe.py @@ -358,3 +358,18 @@ def test_describe_object_file(resources_path): df = pd.read_csv(test_files.test_concat_file1_csv) native_df = df.to_pandas() eval_snowpark_pandas_result(df, native_df, lambda x: x.describe(include="O")) + + +@sql_count_checker(query_count=0) +@pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", +) +def test_timedelta(timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs( + timedelta_native_df, + ), + lambda df: df.describe(), + ) diff --git a/tests/integ/modin/frame/test_dtypes.py b/tests/integ/modin/frame/test_dtypes.py index c3773bdd6de..b078b31f6c5 100644 --- a/tests/integ/modin/frame/test_dtypes.py +++ b/tests/integ/modin/frame/test_dtypes.py @@ -351,7 +351,7 @@ def test_insert_multiindex_multi_label(label1, label2): native_pd.Timestamp(1513393355, unit="s", tz="US/Pacific"), ], "datetime64[ns, America/Los_Angeles]", - "datetime64[ns, America/Los_Angeles]", + "datetime64[ns, UTC-08:00]", "datetime64[ns]", ), ( @@ -372,7 +372,7 @@ def test_insert_multiindex_multi_label(label1, label2): native_pd.Timestamp(1513393355, unit="s", tz="US/Pacific"), ], "object", - "datetime64[ns, America/Los_Angeles]", + 
"datetime64[ns, UTC-08:00]", "datetime64[ns]", ), ], diff --git a/tests/integ/modin/frame/test_idxmax_idxmin.py b/tests/integ/modin/frame/test_idxmax_idxmin.py index 72fe88968bc..87041060bd2 100644 --- a/tests/integ/modin/frame/test_idxmax_idxmin.py +++ b/tests/integ/modin/frame/test_idxmax_idxmin.py @@ -196,8 +196,18 @@ def test_idxmax_idxmin_with_dates(func, axis): @sql_count_checker(query_count=1) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) -@pytest.mark.parametrize("axis", [0, 1]) -@pytest.mark.xfail(reason="SNOW-1625380 TODO") +@pytest.mark.parametrize( + "axis", + [ + 0, + pytest.param( + 1, + marks=pytest.mark.xfail( + strict=True, raises=NotImplementedError, reason="SNOW-1653126" + ), + ), + ], +) def test_idxmax_idxmin_with_timedelta(func, axis): native_df = native_pd.DataFrame( data={ diff --git a/tests/integ/modin/frame/test_info.py b/tests/integ/modin/frame/test_info.py index 2a096e76fdc..fbbf8dfe041 100644 --- a/tests/integ/modin/frame/test_info.py +++ b/tests/integ/modin/frame/test_info.py @@ -13,9 +13,7 @@ def _assert_info_lines_equal(modin_info: list[str], pandas_info: list[str]): # class is different - assert ( - modin_info[0] == "" - ) + assert modin_info[0] == "" assert pandas_info[0] == "" # index is different diff --git a/tests/integ/modin/frame/test_loc.py b/tests/integ/modin/frame/test_loc.py index be51b8c9ae6..105bf475f3a 100644 --- a/tests/integ/modin/frame/test_loc.py +++ b/tests/integ/modin/frame/test_loc.py @@ -4072,3 +4072,22 @@ def test_df_loc_get_with_timedelta_and_none_key(): # Compare with an empty DataFrame, since native pandas raises a KeyError. expected_df = native_pd.DataFrame() assert_frame_equal(snow_df.loc[None], expected_df, check_column_type=False) + + +@sql_count_checker(query_count=0) +def test_df_loc_invalid_key(): + # Bug fix: SNOW-1320674 + native_df = native_pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + snow_df = pd.DataFrame(native_df) + + def op(df): + df["C"] = df["A"] / df["D"] + + eval_snowpark_pandas_result( + snow_df, + native_df, + op, + expect_exception=True, + expect_exception_type=KeyError, + expect_exception_match="D", + ) diff --git a/tests/integ/modin/frame/test_nunique.py b/tests/integ/modin/frame/test_nunique.py index d0cad8ec2ad..78098d34386 100644 --- a/tests/integ/modin/frame/test_nunique.py +++ b/tests/integ/modin/frame/test_nunique.py @@ -11,8 +11,13 @@ from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result -TEST_LABELS = np.array(["A", "B", "C", "D"]) -TEST_DATA = [[0, 1, 2, 3], [0, 0, 0, 0], [None, 0, None, 0], [None, None, None, None]] +TEST_LABELS = np.array(["A", "B", "C", "D", "E"]) +TEST_DATA = [ + [0, 1, 2, 3, pd.Timedelta(4)], + [0, 0, 0, 0, pd.Timedelta(0)], + [None, 0, None, 0, pd.Timedelta(0)], + [None, None, None, None, None], +] # which original dataframe (constructed from slicing) to test for TEST_SLICES = [ @@ -80,7 +85,7 @@ def test_dataframe_nunique_no_columns(native_df): [ pytest.param(None, id="default_columns"), pytest.param( - [["bar", "bar", "baz", "foo"], ["one", "two", "one", "two"]], + [["bar", "bar", "baz", "foo", "foo"], ["one", "two", "one", "two", "one"]], id="2D_columns", ), ], diff --git a/tests/integ/modin/frame/test_skew.py b/tests/integ/modin/frame/test_skew.py index 72fad6cebdc..94b7fd79c24 100644 --- a/tests/integ/modin/frame/test_skew.py +++ b/tests/integ/modin/frame/test_skew.py @@ -8,7 +8,11 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from 
tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import assert_series_equal +from tests.integ.modin.utils import ( + assert_series_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) @sql_count_checker(query_count=1) @@ -62,16 +66,22 @@ def test_skew_basic(): }, "kwargs": {"numeric_only": True, "skipna": True}, }, + { + "frame": { + "A": [pd.Timedelta(1)], + }, + "kwargs": { + "numeric_only": True, + }, + }, ], ) @sql_count_checker(query_count=1) def test_skew(data): - native_df = native_pd.DataFrame(data["frame"]) - snow_df = pd.DataFrame(native_df) - assert_series_equal( - snow_df.skew(**data["kwargs"]), - native_df.skew(**data["kwargs"]), - rtol=1.0e-5, + eval_snowpark_pandas_result( + *create_test_dfs(data["frame"]), + lambda df: df.skew(**data["kwargs"]), + rtol=1.0e-5 ) @@ -103,6 +113,14 @@ def test_skew(data): }, "kwargs": {"level": 2}, }, + { + "frame": { + "A": [pd.Timedelta(1)], + }, + "kwargs": { + "numeric_only": False, + }, + }, ], ) @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/groupby/test_all_any.py b/tests/integ/modin/groupby/test_all_any.py index d5234dfbdb5..df8df44d47c 100644 --- a/tests/integ/modin/groupby/test_all_any.py +++ b/tests/integ/modin/groupby/test_all_any.py @@ -14,7 +14,11 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + assert_frame_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) @pytest.mark.parametrize( @@ -109,3 +113,27 @@ def test_all_any_chained(): lambda df: df.apply(lambda ser: ser.str.len()) ) ) + + +@sql_count_checker(query_count=1) +def test_timedelta_any_with_nulls(): + """ + Test this case separately because pandas behavior is different from Snowpark pandas behavior. 
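+    For a group whose only timedelta value is NaT, pandas' any() returns True while Snowpark pandas
+    returns False, so the two results are asserted against separate expected frames below.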
+ + pandas bug that does not apply to Snowpark pandas: + https://github.com/pandas-dev/pandas/issues/59712 + """ + snow_df, native_df = create_test_dfs( + { + "key": ["a"], + "A": native_pd.Series([pd.NaT], dtype="timedelta64[ns]"), + }, + ) + assert_frame_equal( + native_df.groupby("key").any(), + native_pd.DataFrame({"A": [True]}, index=native_pd.Index(["a"], name="key")), + ) + assert_frame_equal( + snow_df.groupby("key").any(), + native_pd.DataFrame({"A": [False]}, index=native_pd.Index(["a"], name="key")), + ) diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index 09acd49bb21..cbf5b75d48c 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -1096,60 +1096,81 @@ def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df): ) -@pytest.mark.parametrize( - "agg_func", - [ - "count", - "sum", - "mean", - "median", - "std", - ], -) -@pytest.mark.parametrize("by", ["A", "B"]) -@sql_count_checker(query_count=1) -def test_timedelta(agg_func, by): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ) - snow_df = pd.DataFrame(native_df) - - eval_snowpark_pandas_result( - snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)() - ) - - -def test_timedelta_groupby_agg(): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - "C": [True, False, False, True], - } +class TestTimedelta: + @sql_count_checker(query_count=1) + @pytest.mark.parametrize( + "method", + [ + "count", + "mean", + "min", + "max", + "idxmax", + "idxmin", + "sum", + "median", + "std", + "nunique", + ], ) - snow_df = pd.DataFrame(native_df) - with SqlCounter(query_count=1): + @pytest.mark.parametrize("by", ["A", "B"]) + def test_aggregation_methods(self, method, by): eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: getattr(df.groupby(by), method)(), ) - with SqlCounter(query_count=1): - eval_snowpark_pandas_result( - snow_df, - native_df, + + @sql_count_checker(query_count=1) + @pytest.mark.parametrize( + "operation", + [ + lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), lambda df: df.groupby("B").agg({"A": ["sum", "median"], "C": "min"}), + lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}), + lambda df: df.groupby("B").agg(["mean", "std"]), + lambda df: df.groupby("B").agg({"A": ["count", np.sum]}), + lambda df: df.groupby("B").agg({"A": "sum"}), + ], + ) + def test_agg(self, operation): + eval_snowpark_pandas_result( + *create_test_dfs( + native_pd.DataFrame( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + "C": [True, False, False, True], + } + ) + ), + operation, ) - with SqlCounter(query_count=1): + + @sql_count_checker(query_count=1) + def test_groupby_timedelta_var(self): + """ + Test that we can group by a timedelta column and take var() of an integer column. + + Note that we can't take the groupby().var() of the timedelta column because + var() is not defined for timedelta, in pandas or in Snowpark pandas. 
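+        The unsupported case (var() over the timedelta column itself) is exercised in
+        tests/integ/modin/groupby/test_groupby_negative.py::test_timedelta_var_invalid.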
+ """ eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}), + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: df.groupby("A").var(), ) diff --git a/tests/integ/modin/groupby/test_groupby_first_last.py b/tests/integ/modin/groupby/test_groupby_first_last.py index 5da35806dd1..5e04d5a6fc2 100644 --- a/tests/integ/modin/groupby/test_groupby_first_last.py +++ b/tests/integ/modin/groupby/test_groupby_first_last.py @@ -46,6 +46,17 @@ [np.nan], ] ), + "col11_timedelta": [ + pd.Timedelta("1 days"), + None, + pd.Timedelta("2 days"), + None, + None, + None, + None, + None, + None, + ], } diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py index a009e1089b0..0c9c056c2a7 100644 --- a/tests/integ/modin/groupby/test_groupby_negative.py +++ b/tests/integ/modin/groupby/test_groupby_negative.py @@ -18,6 +18,7 @@ MAP_DATA_AND_TYPE, MIXED_NUMERIC_STR_DATA_AND_TYPE, TIMESTAMP_DATA_AND_TYPE, + create_test_dfs, eval_snowpark_pandas_result, ) @@ -559,20 +560,12 @@ def test_groupby_agg_invalid_min_count( @sql_count_checker(query_count=0) -def test_groupby_var_no_support_for_timedelta(): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ) - snow_df = pd.DataFrame(native_df) - with pytest.raises( - NotImplementedError, - match=re.escape( - "SnowflakeQueryCompiler::groupby_agg is not yet implemented for Timedelta Type" +def test_timedelta_var_invalid(): + eval_snowpark_pandas_result( + *create_test_dfs( + [["key0", pd.Timedelta(1)]], ), - ): - snow_df.groupby("B").var() + lambda df: df.groupby(0).var(), + expect_exception=True, + expect_exception_type=TypeError, + ) diff --git a/tests/integ/modin/groupby/test_quantile.py b/tests/integ/modin/groupby/test_quantile.py index b14299fee63..940d366a7e2 100644 --- a/tests/integ/modin/groupby/test_quantile.py +++ b/tests/integ/modin/groupby/test_quantile.py @@ -64,6 +64,14 @@ # ), # All NA ([np.nan] * 5, [np.nan] * 5), + pytest.param( + pd.timedelta_range( + "1 days", + "5 days", + ), + pd.timedelta_range("1 second", "5 second"), + id="timedelta", + ), ], ) @pytest.mark.parametrize("q", [0, 0.5, 1]) diff --git a/tests/integ/modin/index/conftest.py b/tests/integ/modin/index/conftest.py index 3c6362dd83c..26afd232c4f 100644 --- a/tests/integ/modin/index/conftest.py +++ b/tests/integ/modin/index/conftest.py @@ -33,11 +33,11 @@ native_pd.Index(["a", "b", "c", "d"]), native_pd.DatetimeIndex( ["2020-01-01 10:00:00+00:00", "2020-02-01 11:00:00+00:00"], - tz="America/Los_Angeles", + tz="UTC-08:00", ), native_pd.DatetimeIndex( ["2020-01-01 10:00:00+05:00", "2020-02-01 11:00:00+05:00"], - tz="America/Los_Angeles", + tz="UTC", ), native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]), native_pd.TimedeltaIndex(["0 days", "1 days", "3 days"]), @@ -55,11 +55,11 @@ native_pd.Index(["a", "b", 1, 2, None, "a", 2], name="mixed index"), native_pd.DatetimeIndex( ["2020-01-01 10:00:00+00:00", "2020-02-01 11:00:00+00:00"], - tz="America/Los_Angeles", + tz="UTC", ), native_pd.DatetimeIndex( ["2020-01-01 10:00:00+00:00", "2020-01-01 10:00:00+00:00"], - tz="America/Los_Angeles", + tz="UTC-08:00", ), ] @@ -79,4 +79,5 @@ tz="America/Los_Angeles", ), native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]), + 
native_pd.TimedeltaIndex(["4 days", None, "-1 days", "5 days"]), ] diff --git a/tests/integ/modin/index/test_all_any.py b/tests/integ/modin/index/test_all_any.py index 267e7929ea1..499be6f03dc 100644 --- a/tests/integ/modin/index/test_all_any.py +++ b/tests/integ/modin/index/test_all_any.py @@ -25,6 +25,9 @@ native_pd.Index(["a", "b", "c", "d"]), native_pd.Index([5, None, 7]), native_pd.Index([], dtype="object"), + native_pd.Index([pd.Timedelta(0), None]), + native_pd.Index([pd.Timedelta(0)]), + native_pd.Index([pd.Timedelta(0), pd.Timedelta(1)]), ] NATIVE_INDEX_EMPTY_DATA = [ diff --git a/tests/integ/modin/index/test_argmax_argmin.py b/tests/integ/modin/index/test_argmax_argmin.py index 6d446a0a66a..7d42f3b88c9 100644 --- a/tests/integ/modin/index/test_argmax_argmin.py +++ b/tests/integ/modin/index/test_argmax_argmin.py @@ -18,6 +18,18 @@ native_pd.Index([4, None, 1, 3, 4, 1]), native_pd.Index([4, None, 1, 3, 4, 1], name="some name"), native_pd.Index([1, 10, 4, 3, 4]), + pytest.param( + native_pd.Index( + [ + pd.Timedelta(1), + pd.Timedelta(10), + pd.Timedelta(4), + pd.Timedelta(3), + pd.Timedelta(4), + ] + ), + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["argmax", "argmin"]) diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py index 56fd40a6cb3..98d1a041c3b 100644 --- a/tests/integ/modin/index/test_datetime_index_methods.py +++ b/tests/integ/modin/index/test_datetime_index_methods.py @@ -4,8 +4,10 @@ import re import modin.pandas as pd +import numpy as np import pandas as native_pd import pytest +import pytz import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker @@ -16,6 +18,46 @@ eval_snowpark_pandas_result, ) +timezones = pytest.mark.parametrize( + "tz", + [ + None, + # Use a subset of pytz.common_timezones containing a few timezones in each + *[ + param_for_one_tz + for tz in [ + "Africa/Abidjan", + "Africa/Timbuktu", + "America/Adak", + "America/Yellowknife", + "Antarctica/Casey", + "Asia/Dhaka", + "Asia/Manila", + "Asia/Shanghai", + "Atlantic/Stanley", + "Australia/Sydney", + "Canada/Pacific", + "Europe/Chisinau", + "Europe/Luxembourg", + "Indian/Christmas", + "Pacific/Chatham", + "Pacific/Wake", + "US/Arizona", + "US/Central", + "US/Eastern", + "US/Hawaii", + "US/Mountain", + "US/Pacific", + "UTC", + ] + for param_for_one_tz in ( + pytz.timezone(tz), + tz, + ) + ], + ], +) + @sql_count_checker(query_count=0) def test_datetime_index_construction(): @@ -100,13 +142,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame({"A": [1]}, index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. 
s = pd.Series([1, 2], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) @@ -232,6 +274,76 @@ def test_normalize(): ) +@sql_count_checker(query_count=1, join_count=1) +@timezones +def test_tz_convert(tz): + native_index = native_pd.date_range( + start="2021-01-01", periods=5, freq="7h", tz="US/Eastern" + ) + native_index = native_index.append( + native_pd.DatetimeIndex([pd.NaT], tz="US/Eastern") + ) + snow_index = pd.DatetimeIndex(native_index) + + # Using eval_snowpark_pandas_result() was not possible because currently + # Snowpark pandas DatetimeIndex only mainains a timzeone-naive dtype + # even if the data contains a timezone. + assert snow_index.tz_convert(tz).equals( + pd.DatetimeIndex(native_index.tz_convert(tz)) + ) + + +@sql_count_checker(query_count=1, join_count=1) +@timezones +def test_tz_localize(tz): + native_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + snow_index = pd.DatetimeIndex(native_index) + + # Using eval_snowpark_pandas_result() was not possible because currently + # Snowpark pandas DatetimeIndex only mainains a timzeone-naive dtype + # even if the data contains a timezone. + assert snow_index.tz_localize(tz).equals( + pd.DatetimeIndex(native_index.tz_localize(tz)) + ) + + +@pytest.mark.parametrize( + "ambiguous, nonexistent", + [ + ("infer", "raise"), + ("NaT", "raise"), + (np.array([True, True, False]), "raise"), + ("raise", "shift_forward"), + ("raise", "shift_backward"), + ("raise", "NaT"), + ("raise", pd.Timedelta("1h")), + ("infer", "shift_forward"), + ], +) +@sql_count_checker(query_count=0) +def test_tz_localize_negative(ambiguous, nonexistent): + native_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + snow_index = pd.DatetimeIndex(native_index) + with pytest.raises(NotImplementedError): + snow_index.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent) + + @pytest.mark.parametrize( "datetime_index_value", [ @@ -268,7 +380,12 @@ def test_floor_ceil_round(datetime_index_value, func, freq): [ ("1w", "raise", "raise"), ("1h", "infer", "raise"), + ("1h", "NaT", "raise"), + ("1h", np.array([True, True, False]), "raise"), ("1h", "raise", "shift_forward"), + ("1h", "raise", "shift_backward"), + ("1h", "raise", "NaT"), + ("1h", "raise", pd.Timedelta("1h")), ("1w", "infer", "shift_forward"), ], ) diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py index 8d0434915ac..6b33eb89889 100644 --- a/tests/integ/modin/index/test_index_methods.py +++ b/tests/integ/modin/index/test_index_methods.py @@ -393,13 +393,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. 
s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) diff --git a/tests/integ/modin/index/test_name.py b/tests/integ/modin/index/test_name.py index b916110f386..f915598c5f6 100644 --- a/tests/integ/modin/index/test_name.py +++ b/tests/integ/modin/index/test_name.py @@ -351,3 +351,69 @@ def test_index_names_with_lazy_index(): ), inplace=True, ) + + +@sql_count_checker(query_count=1) +def test_index_names_replace_behavior(): + """ + Check that the index name of a DataFrame cannot be updated after the DataFrame has been modified. + """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change the names. + snow_index.name = "test2" + native_index.name = "test2" + + # Compare the names. + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the query compiler the DataFrame is referring to, change the names. + snow_df.dropna(inplace=True) + native_df.dropna(inplace=True) + snow_index.name = "test3" + native_index.name = "test3" + + # Compare the names. Changing the index name should not change the DataFrame's index name. + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test2" + + +@sql_count_checker(query_count=1) +def test_index_names_multiple_renames(): + """ + Check that the index name of a DataFrame can be renamed any number of times. + """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change and compare the names. + snow_index.name = "test2" + native_index.name = "test2" + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the names again and compare. 
+ snow_index.name = "test3" + native_index.name = "test3" + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test3" diff --git a/tests/integ/modin/index/test_timedelta_index_methods.py b/tests/integ/modin/index/test_timedelta_index_methods.py index 25bef5364f2..c4d4a0b3a66 100644 --- a/tests/integ/modin/index/test_timedelta_index_methods.py +++ b/tests/integ/modin/index/test_timedelta_index_methods.py @@ -128,3 +128,29 @@ def test_timedelta_total_seconds(): native_index = native_pd.TimedeltaIndex(TIMEDELTA_INDEX_DATA) snow_index = pd.Index(native_index) eval_snowpark_pandas_result(snow_index, native_index, lambda x: x.total_seconds()) + + +@pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.parametrize("data", [[1, 2, 3], [1, 2, 3, None], [None], []]) +@sql_count_checker(query_count=1) +def test_timedelta_index_mean(skipna, data): + native_index = native_pd.TimedeltaIndex(data) + snow_index = pd.Index(native_index) + native_result = native_index.mean(skipna=skipna) + snow_result = snow_index.mean(skipna=skipna) + # Special check for NaN because Nan != Nan. + if pd.isna(native_result): + assert pd.isna(snow_result) + else: + assert snow_result == native_result + + +@sql_count_checker(query_count=0) +def test_timedelta_index_mean_invalid_axis(): + native_index = native_pd.TimedeltaIndex([1, 2, 3]) + snow_index = pd.Index(native_index) + with pytest.raises(IndexError, match="tuple index out of range"): + native_index.mean(axis=1) + # Snowpark pandas raises ValueError instead of IndexError. + with pytest.raises(ValueError, match="axis should be 0 for TimedeltaIndex.mean"): + snow_index.mean(axis=1).to_pandas() diff --git a/tests/integ/modin/series/test_aggregate.py b/tests/integ/modin/series/test_aggregate.py index fa354fda1fc..c3e40828d94 100644 --- a/tests/integ/modin/series/test_aggregate.py +++ b/tests/integ/modin/series/test_aggregate.py @@ -1,6 +1,8 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import re + import modin.pandas as pd import numpy as np import pandas as native_pd @@ -17,6 +19,7 @@ MAP_DATA_AND_TYPE, MIXED_NUMERIC_STR_DATA_AND_TYPE, TIMESTAMP_DATA_AND_TYPE, + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_test_series, eval_snowpark_pandas_result, ) @@ -358,3 +361,67 @@ def test_2_tuple_named_agg_errors_for_series(native_series, agg_kwargs): expect_exception_type=SpecificationError, assert_exception_equal=True, ) + + +class TestTimedelta: + """Test aggregating a timedelta series.""" + + @pytest.mark.parametrize( + "func, union_count, is_scalar", + [ + pytest.param(*v, id=str(i)) + for i, v in enumerate( + [ + (lambda series: series.aggregate(["min"]), 0, False), + (lambda series: series.aggregate({"A": "max"}), 0, False), + # this works since even though we need to do concats, all the results are non-timdelta. 
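+                    # ("all", "any", and "count" each produce non-timedelta results, so the concat is safe.)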
+ (lambda df: df.aggregate(["all", "any", "count"]), 2, False), + # note following aggregation requires transpose + (lambda df: df.aggregate(max), 0, True), + (lambda df: df.min(), 0, True), + (lambda df: df.max(), 0, True), + (lambda df: df.count(), 0, True), + (lambda df: df.sum(), 0, True), + (lambda df: df.mean(), 0, True), + (lambda df: df.median(), 0, True), + (lambda df: df.std(), 0, True), + (lambda df: df.quantile(), 0, True), + (lambda df: df.quantile([0.01, 0.99]), 0, False), + ] + ) + ], + ) + def test_supported(self, func, union_count, timedelta_native_df, is_scalar): + with SqlCounter(query_count=1, union_count=union_count): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + func, + comparator=validate_scalar_result + if is_scalar + else assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + ) + + @sql_count_checker(query_count=0) + def test_var_invalid(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + lambda series: series.var(), + expect_exception=True, + expect_exception_type=TypeError, + assert_exception_equal=False, + expect_exception_match=re.escape( + "timedelta64 type does not support var operations" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", + ) + def test_unsupported_due_to_concat(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + lambda df: df.agg(["count", "max"]), + ) diff --git a/tests/integ/modin/series/test_argmax_argmin.py b/tests/integ/modin/series/test_argmax_argmin.py index 607b36a27f3..e212e3ba2dd 100644 --- a/tests/integ/modin/series/test_argmax_argmin.py +++ b/tests/integ/modin/series/test_argmax_argmin.py @@ -18,6 +18,11 @@ ([4, None, 1, 3, 4, 1], ["A", "B", "C", "D", "E", "F"]), ([4, None, 1, 3, 4, 1], [None, "B", "C", "D", "E", "F"]), ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + pytest.param( + [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)], + ["A", "B", "C", "D", "E"], + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["argmax", "argmin"]) diff --git a/tests/integ/modin/series/test_astype.py b/tests/integ/modin/series/test_astype.py index 5bbce79b01b..030416d65c5 100644 --- a/tests/integ/modin/series/test_astype.py +++ b/tests/integ/modin/series/test_astype.py @@ -173,6 +173,11 @@ def test_astype_basic(from_dtype, to_dtype): ) def test_astype_to_DatetimeTZDtype(from_dtype, to_tz): to_dtype = f"datetime64[ns, {to_tz}]" + offset_map = { + "UTC": "UTC", + "Asia/Tokyo": "UTC+09:00", + "America/Los_Angeles": "UTC-08:00", + } seed = ( [True, False, False, True] # if isinstance(from_dtype, BooleanDtype) @@ -189,23 +194,22 @@ def test_astype_to_DatetimeTZDtype(from_dtype, to_tz): native_pd.Series(seed, dtype=from_dtype).astype(to_dtype) elif isinstance(from_dtype, StringDtype) or from_dtype is str: # Snowpark pandas use Snowflake auto format detection and the behavior can be different from native pandas - # to_pandas always convert timezone to the local timezone today, i.e., "America/Los_angeles" with SqlCounter(query_count=1): assert_snowpark_pandas_equal_to_pandas( pd.Series(seed, dtype=from_dtype).astype(to_dtype), native_pd.Series( [ native_pd.Timestamp("1970-01-01 00:00:00", tz="UTC").tz_convert( - "America/Los_Angeles" + offset_map[to_tz] ), native_pd.Timestamp("1970-01-01 00:00:01", tz="UTC").tz_convert( - 
"America/Los_Angeles" + offset_map[to_tz] ), native_pd.Timestamp("1970-01-01 00:00:02", tz="UTC").tz_convert( - "America/Los_Angeles" + offset_map[to_tz] ), native_pd.Timestamp("1970-01-01 00:00:03", tz="UTC").tz_convert( - "America/Los_Angeles" + offset_map[to_tz] ), ] ), @@ -251,15 +255,15 @@ def test_astype_to_DatetimeTZDtype(from_dtype, to_tz): ): native_pd.Series(seed, dtype=from_dtype).astype(to_dtype) expected_to_pandas = ( - native_pd.Series(seed, dtype=from_dtype).dt.tz_localize("UTC") - # Snowpark pandas to_pandas() will convert timestamp_tz to default local timezone - .dt.tz_convert("America/Los_Angeles") + native_pd.Series(seed, dtype=from_dtype) + .dt.tz_localize("UTC") + .dt.tz_convert(offset_map[to_tz]) ) else: expected_to_pandas = ( - native_pd.Series(seed, dtype=from_dtype).astype(to_dtype) - # Snowpark pandas to_pandas() will convert timestamp_tz to default local timezone - .dt.tz_convert("America/Los_Angeles") + native_pd.Series(seed, dtype=from_dtype) + .astype(to_dtype) + .dt.tz_convert(offset_map[to_tz]) ) assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( s, @@ -392,11 +396,7 @@ def test_python_datetime_astype_DatetimeTZDtype(seed): with SqlCounter(query_count=1): snow = s.astype(to_dtype) assert snow.dtype == np.dtype(" native_pd.Series: return native_pd.Series( @@ -140,7 +183,12 @@ def test_floor_ceil_round(datetime_index_value, func, freq): [ ("1w", "raise", "raise"), ("1h", "infer", "raise"), + ("1h", "NaT", "raise"), + ("1h", np.array([True, True, False]), "raise"), ("1h", "raise", "shift_forward"), + ("1h", "raise", "shift_backward"), + ("1h", "raise", "NaT"), + ("1h", "raise", pd.Timedelta("1h")), ("1w", "infer", "shift_forward"), ], ) @@ -174,6 +222,79 @@ def test_normalize(): ) +@sql_count_checker(query_count=1) +@timezones +def test_tz_convert(tz): + datetime_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + tz="US/Eastern", + ) + native_ser = native_pd.Series(datetime_index) + snow_ser = pd.Series(native_ser) + eval_snowpark_pandas_result( + snow_ser, + native_ser, + lambda s: s.dt.tz_convert(tz), + ) + + +@sql_count_checker(query_count=1) +@timezones +def test_tz_localize(tz): + datetime_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + native_ser = native_pd.Series(datetime_index) + snow_ser = pd.Series(native_ser) + eval_snowpark_pandas_result( + snow_ser, + native_ser, + lambda s: s.dt.tz_localize(tz), + ) + + +@pytest.mark.parametrize( + "ambiguous, nonexistent", + [ + ("infer", "raise"), + ("NaT", "raise"), + (np.array([True, True, False]), "raise"), + ("raise", "shift_forward"), + ("raise", "shift_backward"), + ("raise", "NaT"), + ("raise", pd.Timedelta("1h")), + ("infer", "shift_forward"), + ], +) +@sql_count_checker(query_count=0) +def test_tz_localize_negative(ambiguous, nonexistent): + datetime_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + native_ser = native_pd.Series(datetime_index) + snow_ser = pd.Series(native_ser) + with pytest.raises(NotImplementedError): + snow_ser.dt.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent) + + @pytest.mark.parametrize("name", [None, "hello"]) def 
test_isocalendar(name): with SqlCounter(query_count=1): diff --git a/tests/integ/modin/series/test_first_last_valid_index.py b/tests/integ/modin/series/test_first_last_valid_index.py index 1e8d052e10f..1930bdf1088 100644 --- a/tests/integ/modin/series/test_first_last_valid_index.py +++ b/tests/integ/modin/series/test_first_last_valid_index.py @@ -22,6 +22,10 @@ native_pd.Series([5, 6, 7, 8], index=["i", "am", "iron", "man"]), native_pd.Series([None, None, 2], index=[None, 1, 2]), native_pd.Series([None, None, 2], index=[None, None, None]), + pytest.param( + native_pd.Series([None, None, pd.Timedelta(2)], index=[None, 1, 2]), + id="timedelta", + ), ], ) def test_first_and_last_valid_index_series(native_series): diff --git a/tests/integ/modin/series/test_idxmax_idxmin.py b/tests/integ/modin/series/test_idxmax_idxmin.py index ea536240a42..e8e66a30f61 100644 --- a/tests/integ/modin/series/test_idxmax_idxmin.py +++ b/tests/integ/modin/series/test_idxmax_idxmin.py @@ -17,6 +17,11 @@ ([1, None, 4, 3, 4], ["A", "B", "C", "D", "E"]), ([1, None, 4, 3, 4], [None, "B", "C", "D", "E"]), ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + pytest.param( + [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)], + ["A", "B", "C", "D", "E"], + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) diff --git a/tests/integ/modin/series/test_nunique.py b/tests/integ/modin/series/test_nunique.py index bb20e9e4a53..3856dbc516a 100644 --- a/tests/integ/modin/series/test_nunique.py +++ b/tests/integ/modin/series/test_nunique.py @@ -6,6 +6,7 @@ import numpy as np import pandas as native_pd import pytest +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker @@ -32,6 +33,20 @@ [True, None, False, True, None], [1.1, "a", None] * 4, [native_pd.to_datetime("2023-12-01"), native_pd.to_datetime("1999-09-09")] * 2, + param( + [ + native_pd.Timedelta(1), + native_pd.Timedelta(1), + native_pd.Timedelta(2), + None, + None, + ], + id="timedelta_with_nulls", + ), + param( + [native_pd.Timedelta(1), native_pd.Timedelta(1), native_pd.Timedelta(2)], + id="timedelta_without_nulls", + ), ], ) @pytest.mark.parametrize("dropna", [True, False]) diff --git a/tests/integ/modin/test_classes.py b/tests/integ/modin/test_classes.py index c92bb85c531..6e6c2eda8eb 100644 --- a/tests/integ/modin/test_classes.py +++ b/tests/integ/modin/test_classes.py @@ -34,14 +34,14 @@ def test_class_names_constructors(): expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) s = pd.Series(index=[1, 2, 3], data=[3, 2, 1]) expect_type_check( s, pd.Series, - "snowflake.snowpark.modin.pandas.series.Series", + "modin.pandas.series.Series", ) @@ -63,7 +63,7 @@ def test_op(): expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) @@ -77,7 +77,7 @@ def test_native_conversion(): expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) # Snowpark pandas -> native pandas diff --git a/tests/integ/modin/test_dtype_mapping.py b/tests/integ/modin/test_dtype_mapping.py index 868a37ff22d..2e474c2aec4 100644 --- a/tests/integ/modin/test_dtype_mapping.py +++ b/tests/integ/modin/test_dtype_mapping.py @@ -281,15 +281,11 @@ "timestamp_tz timestamp_tz", "values ('2023-01-01 00:00:01.001 +0000'), ('2023-12-31 23:59:59.999 
+1000')", # timestamp_tz only supports tz offset dtype(" from_pandas => TIMESTAMP_TZ(any_tz) => to_pandas => DatetimeTZDtype(session_tz) - # - # Note that python connector will convert any TIMESTAMP_TZ to DatetimeTZDtype with the current session/statement - # timezone, e.g., 1969-12-31 19:00:00 -0500 will be converted to 1970-00-01 00:00:00 in UTC if the session/statement - # parameter TIMEZONE = 'UTC' - # TODO: SNOW-871210 no need session parameter change once the bug is fixed - try: - session.sql(f"alter session set timezone = '{timezone}'").collect() - - def get_series_with_tz(tz): - return ( - native_pd.Series([1] * 3) - .astype("int64") - .astype(f"datetime64[ns, {tz}]") - ) +@sql_count_checker(query_count=1) +def test_from_to_pandas_datetime64_timezone_support(): + def get_series_with_tz(tz): + return native_pd.Series([1] * 3).astype("int64").astype(f"datetime64[ns, {tz}]") - # same timestamps representing in different time zone - test_data_columns = { - "utc": get_series_with_tz("UTC"), - "pacific": get_series_with_tz("US/Pacific"), - "tokyo": get_series_with_tz("Asia/Tokyo"), - } + # same timestamps representing in different time zone + test_data_columns = { + "utc": get_series_with_tz("UTC"), + "pacific": get_series_with_tz("US/Pacific"), + "tokyo": get_series_with_tz("Asia/Tokyo"), + } - # expected to_pandas dataframe's timezone is controlled by session/statement parameter TIMEZONE - expected_to_pandas = native_pd.DataFrame( - { - series: test_data_columns[series].dt.tz_convert(timezone) - for series in test_data_columns - } - ) - assert_snowpark_pandas_equal_to_pandas( - pd.DataFrame(test_data_columns), - expected_to_pandas, - # configure different timezones to to_pandas and verify the timestamps are converted correctly - statement_params={"timezone": timezone}, - ) - finally: - # TODO: SNOW-871210 no need session parameter change once the bug is fixed - session.sql("alter session unset timezone").collect() + expected_data_columns = { + "utc": get_series_with_tz("UTC"), + "pacific": get_series_with_tz("UTC-08:00"), + "tokyo": get_series_with_tz("UTC+09:00"), + } + # expected to_pandas dataframe's timezone is controlled by session/statement parameter TIMEZONE + expected_to_pandas = native_pd.DataFrame(expected_data_columns) + assert_snowpark_pandas_equal_to_pandas( + pd.DataFrame(test_data_columns), + expected_to_pandas, + ) -@pytest.mark.parametrize("timezone", ["UTC", "US/Pacific", "US/Eastern"]) -@sql_count_checker(query_count=3) -def test_from_to_pandas_datetime64_multi_timezone_current_behavior(session, timezone): - try: - # TODO: SNOW-871210 no need session parameter change once the bug is fixed - session.sql(f"alter session set timezone = '{timezone}'").collect() - - # This test also verifies the current behaviors of to_pandas() for datetime with no tz, same tz, or multi tz: - # no tz => TIMESTAMP_NTZ - # same tz => TIMESTAMP_TZ - # multi tz => TIMESTAMP_NTZ - multi_tz_data = ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-07:00"] - test_data_columns = { - "no tz": native_pd.to_datetime( - native_pd.Series(["2019-05-21 12:00:00", "2019-05-21 12:15:00"]) - ), # dtype = datetime64[ns] - "same tz": native_pd.to_datetime( - native_pd.Series( - ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-06:00"] - ) - ), # dtype = datetime64[ns, tz] - "multi tz": native_pd.to_datetime( - native_pd.Series(multi_tz_data) - ), # dtype = object and value type is Python datetime - } +@sql_count_checker(query_count=1) +def test_from_to_pandas_datetime64_multi_timezone_current_behavior(): + 
# This test also verifies the current behaviors of to_pandas() for datetime with no tz, same tz, or multi tz: + # no tz => TIMESTAMP_NTZ + # same tz => TIMESTAMP_TZ + # multi tz => TIMESTAMP_TZ + multi_tz_data = ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-07:00"] + test_data_columns = { + "no tz": native_pd.to_datetime( + native_pd.Series(["2019-05-21 12:00:00", "2019-05-21 12:15:00"]) + ), # dtype = datetime64[ns] + "same tz": native_pd.to_datetime( + native_pd.Series(["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-06:00"]) + ), # dtype = datetime64[ns, tz] + "multi tz": native_pd.to_datetime( + native_pd.Series(multi_tz_data) + ), # dtype = object and value type is Python datetime + } - expected_to_pandas = native_pd.DataFrame( - { - "no tz": test_data_columns["no tz"], # dtype = datetime64[ns] - "same tz": test_data_columns["same tz"].dt.tz_convert( - timezone - ), # dtype = datetime64[ns, tz] - "multi tz": native_pd.Series( - [ - native_pd.to_datetime(t).tz_convert(timezone) - for t in multi_tz_data - ] - ), - } - ) + expected_to_pandas = native_pd.DataFrame(test_data_columns) - test_df = native_pd.DataFrame(test_data_columns) - # dtype checks for each series - no_tz_dtype = test_df.dtypes["no tz"] - assert is_datetime64_any_dtype(no_tz_dtype) and not isinstance( - no_tz_dtype, DatetimeTZDtype - ) - same_tz_dtype = test_df.dtypes["same tz"] - assert is_datetime64_any_dtype(same_tz_dtype) and isinstance( - same_tz_dtype, DatetimeTZDtype - ) - multi_tz_dtype = test_df.dtypes["multi tz"] - assert ( - not is_datetime64_any_dtype(multi_tz_dtype) - and not isinstance(multi_tz_dtype, DatetimeTZDtype) - and str(multi_tz_dtype) == "object" - ) - # sample value - assert isinstance(test_df["multi tz"][0], datetime.datetime) - assert test_df["multi tz"][0].tzinfo is not None - assert_snowpark_pandas_equal_to_pandas( - pd.DataFrame(test_df), - expected_to_pandas, - statement_params={"timezone": timezone}, - ) - finally: - # TODO: SNOW-871210 no need session parameter change once the bug is fixed - session.sql("alter session unset timezone").collect() + test_df = native_pd.DataFrame(test_data_columns) + # dtype checks for each series + no_tz_dtype = test_df.dtypes["no tz"] + assert is_datetime64_any_dtype(no_tz_dtype) and not isinstance( + no_tz_dtype, DatetimeTZDtype + ) + same_tz_dtype = test_df.dtypes["same tz"] + assert is_datetime64_any_dtype(same_tz_dtype) and isinstance( + same_tz_dtype, DatetimeTZDtype + ) + multi_tz_dtype = test_df.dtypes["multi tz"] + assert ( + not is_datetime64_any_dtype(multi_tz_dtype) + and not isinstance(multi_tz_dtype, DatetimeTZDtype) + and str(multi_tz_dtype) == "object" + ) + # sample value + assert isinstance(test_df["multi tz"][0], datetime.datetime) + assert test_df["multi tz"][0].tzinfo is not None + assert_snowpark_pandas_equal_to_pandas( + pd.DataFrame(test_df), + expected_to_pandas, + ) @sql_count_checker(query_count=1) diff --git a/tests/integ/modin/test_merge_asof.py b/tests/integ/modin/test_merge_asof.py index 681d339da90..51dda7889e7 100644 --- a/tests/integ/modin/test_merge_asof.py +++ b/tests/integ/modin/test_merge_asof.py @@ -105,6 +105,7 @@ def left_right_timestamp_data(): pd.Timestamp("2016-05-25 13:30:00.072"), pd.Timestamp("2016-05-25 13:30:00.075"), ], + "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"], "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], } @@ -118,6 +119,7 @@ def left_right_timestamp_data(): 
pd.Timestamp("2016-05-25 13:30:00.048"), pd.Timestamp("2016-05-25 13:30:00.048"), ], + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], "price": [51.95, 51.95, 720.77, 720.92, 98.0], "quantity": [75, 155, 100, 100, 100], } @@ -229,14 +231,39 @@ def test_merge_asof_left_right_on( assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) +@pytest.mark.parametrize("by", ["ticker", ["ticker"]]) @sql_count_checker(query_count=1, join_count=1) -def test_merge_asof_timestamps(left_right_timestamp_data): +def test_merge_asof_by(left_right_timestamp_data, by): left_native_df, right_native_df = left_right_timestamp_data left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( right_native_df ) - native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time") - snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time") + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", by=by + ) + snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by=by) + assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) + + +@pytest.mark.parametrize( + "left_by, right_by", + [ + ("ticker", "ticker"), + (["ticker", "bid"], ["ticker", "price"]), + ], +) +@sql_count_checker(query_count=1, join_count=1) +def test_merge_asof_left_right_by(left_right_timestamp_data, left_by, right_by): + left_native_df, right_native_df = left_right_timestamp_data + left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( + right_native_df + ) + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", left_by=left_by, right_by=right_by + ) + snow_output = pd.merge_asof( + left_snow_df, right_snow_df, on="time", left_by=left_by, right_by=right_by + ) assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) @@ -248,8 +275,10 @@ def test_merge_asof_date(left_right_timestamp_data): left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( right_native_df ) - native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time") - snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time") + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", by="ticker" + ) + snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by="ticker") assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) @@ -360,9 +389,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): with pytest.raises( NotImplementedError, match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'" ), ): pd.merge_asof( @@ -372,19 +399,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" - ), - ): - pd.merge_asof( - left_snow_df, right_snow_df, on="time", left_by="price", right_by="quantity" - ) - with pytest.raises( - NotImplementedError, - match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof(left_snow_df, 
right_snow_df, left_index=True, right_index=True) @@ -392,8 +407,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof( @@ -406,8 +420,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof( diff --git a/tests/integ/modin/test_telemetry.py b/tests/integ/modin/test_telemetry.py index ce9e1caf328..a36298af251 100644 --- a/tests/integ/modin/test_telemetry.py +++ b/tests/integ/modin/test_telemetry.py @@ -110,7 +110,7 @@ def test_snowpark_pandas_telemetry_method_decorator(test_table_name): df1_expected_api_calls = [ {"name": "TestClass.test_func"}, - {"name": "DataFrame.DataFrame.dropna", "argument": ["inplace"]}, + {"name": "DataFrame.dropna", "argument": ["inplace"]}, ] assert df1._query_compiler.snowpark_pandas_api_calls == df1_expected_api_calls @@ -121,7 +121,7 @@ def test_snowpark_pandas_telemetry_method_decorator(test_table_name): assert df1._query_compiler.snowpark_pandas_api_calls == df1_expected_api_calls df2_expected_api_calls = df1_expected_api_calls + [ { - "name": "DataFrame.DataFrame.dropna", + "name": "DataFrame.dropna", }, ] assert df2._query_compiler.snowpark_pandas_api_calls == df2_expected_api_calls @@ -336,10 +336,7 @@ def test_telemetry_with_update_inplace(): df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) df.insert(1, "newcol", [99, 99, 90]) assert len(df._query_compiler.snowpark_pandas_api_calls) == 1 - assert ( - df._query_compiler.snowpark_pandas_api_calls[0]["name"] - == "DataFrame.DataFrame.insert" - ) + assert df._query_compiler.snowpark_pandas_api_calls[0]["name"] == "DataFrame.insert" @sql_count_checker(query_count=1) @@ -403,8 +400,8 @@ def test_telemetry_getitem_setitem(): df["a"] = 0 df["b"] = 0 assert df._query_compiler.snowpark_pandas_api_calls == [ - {"name": "DataFrame.DataFrame.__setitem__"}, - {"name": "DataFrame.DataFrame.__setitem__"}, + {"name": "DataFrame.__setitem__"}, + {"name": "DataFrame.__setitem__"}, ] # Clear connector telemetry client buffer to avoid flush triggered by the next API call, ensuring log extraction. s._query_compiler._modin_frame.ordered_dataframe.session._conn._telemetry_client.telemetry.send_batch() @@ -422,13 +419,17 @@ def test_telemetry_getitem_setitem(): @pytest.mark.parametrize( - "name, method, expected_query_count", + "name, expected_func_name, method, expected_query_count", [ - ["__repr__", lambda df: df.__repr__(), 1], - ["__iter__", lambda df: df.__iter__(), 0], + # __repr__ is an extension method, so the class name is shown only once. + ["__repr__", "DataFrame.__repr__", lambda df: df.__repr__(), 1], + # __iter__ was defined on the DataFrame class, so it is shown twice. 
+ ["__iter__", "DataFrame.DataFrame.__iter__", lambda df: df.__iter__(), 0], ], ) -def test_telemetry_private_method(name, method, expected_query_count): +def test_telemetry_private_method( + name, expected_func_name, method, expected_query_count +): df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) # Clear connector telemetry client buffer to avoid flush triggered by the next API call, ensuring log extraction. df._query_compiler._modin_frame.ordered_dataframe.session._conn._telemetry_client.telemetry.send_batch() @@ -439,10 +440,10 @@ def test_telemetry_private_method(name, method, expected_query_count): # the telemetry log from the connector to validate data = _extract_snowpark_pandas_telemetry_log_data( - expected_func_name=f"DataFrame.DataFrame.{name}", + expected_func_name=expected_func_name, session=df._query_compiler._modin_frame.ordered_dataframe.session, ) - assert data["api_calls"] == [{"name": f"DataFrame.DataFrame.{name}"}] + assert data["api_calls"] == [{"name": expected_func_name}] @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/tools/test_to_datetime.py b/tests/integ/modin/tools/test_to_datetime.py index 1ea3445d15a..df11e6afb80 100644 --- a/tests/integ/modin/tools/test_to_datetime.py +++ b/tests/integ/modin/tools/test_to_datetime.py @@ -565,7 +565,7 @@ def test_to_datetime_mixed_datetime_and_string(self): assert_index_equal(res, expected) # Set utc=True to make sure timezone aware in to_datetime res = to_datetime(pd.Index(["2020-01-01 17:00:00 -0100", d2]), utc=True) - expected = pd.DatetimeIndex([d1, d2]) + expected = pd.DatetimeIndex([d1, d2], tz="UTC") assert_index_equal(res, expected) @pytest.mark.parametrize( diff --git a/tests/integ/modin/types/test_timedelta.py b/tests/integ/modin/types/test_timedelta.py index 4c72df42bba..d28362374ce 100644 --- a/tests/integ/modin/types/test_timedelta.py +++ b/tests/integ/modin/types/test_timedelta.py @@ -2,10 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# import datetime +import warnings import modin.pandas as pd import pandas as native_pd import pytest +from pandas.errors import SettingWithCopyWarning from tests.integ.modin.sql_counter import sql_count_checker from tests.integ.modin.utils import ( @@ -107,3 +109,10 @@ def test_timedelta_not_supported(): match="SnowflakeQueryCompiler::groupby_groups is not yet implemented for Timedelta Type", ): df.groupby("a").groups() + + +@sql_count_checker(query_count=1) +def test_aggregation_does_not_print_internal_warning_SNOW_1664064(): + with warnings.catch_warnings(): + warnings.simplefilter(category=SettingWithCopyWarning, action="error") + pd.Series(pd.Timedelta(1)).max() diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index e42a504a976..bdd780ea69e 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -9,9 +9,13 @@ import pytest from snowflake.snowpark._internal.analyzer import analyzer -from snowflake.snowpark._internal.compiler import large_query_breakdown from snowflake.snowpark.functions import col, lit, sum_distinct, when_matched from snowflake.snowpark.row import Row +from snowflake.snowpark.session import ( + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + Session, +) from tests.utils import Utils pytestmark = [ @@ -22,9 +26,6 @@ ) ] -DEFAULT_LOWER_BOUND = large_query_breakdown.COMPLEXITY_SCORE_LOWER_BOUND -DEFAULT_UPPER_BOUND = large_query_breakdown.COMPLEXITY_SCORE_UPPER_BOUND - @pytest.fixture(autouse=True) def large_query_df(session): @@ -50,20 +51,24 @@ def setup(session): is_query_compilation_stage_enabled = session._query_compilation_stage_enabled session._query_compilation_stage_enabled = True session._large_query_breakdown_enabled = True + set_bounds(session, 300, 600) yield session._query_compilation_stage_enabled = is_query_compilation_stage_enabled session._cte_optimization_enabled = cte_optimization_enabled session._large_query_breakdown_enabled = large_query_breakdown_enabled - reset_bounds() + reset_bounds(session) -def set_bounds(lower_bound: int, upper_bound: int): - large_query_breakdown.COMPLEXITY_SCORE_LOWER_BOUND = lower_bound - large_query_breakdown.COMPLEXITY_SCORE_UPPER_BOUND = upper_bound +def set_bounds(session: Session, lower_bound: int, upper_bound: int): + session._large_query_breakdown_complexity_bounds = (lower_bound, upper_bound) -def reset_bounds(): - set_bounds(DEFAULT_LOWER_BOUND, DEFAULT_UPPER_BOUND) +def reset_bounds(session: Session): + set_bounds( + session, + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + ) def check_result_with_and_without_breakdown(session, df): @@ -82,8 +87,6 @@ def check_result_with_and_without_breakdown(session, df): def test_no_valid_nodes_found(session, large_query_df, caplog): """Test large query breakdown works with default bounds""" - set_bounds(300, 600) - base_df = session.sql("select 1 as A, 2 as B") df1 = base_df.with_column("A", col("A") + lit(1)) df2 = base_df.with_column("B", col("B") + lit(1)) @@ -104,7 +107,6 @@ def test_no_valid_nodes_found(session, large_query_df, caplog): def test_large_query_breakdown_with_cte_optimization(session): """Test large query breakdown works with cte optimized plan""" - set_bounds(300, 600) session._cte_optimization_enabled = True df0 = session.sql("select 2 as b, 32 as c") df1 = session.sql("select 1 as a, 2 as b").filter(col("a") == 1) @@ -131,7 +133,6 @@ def 
test_large_query_breakdown_with_cte_optimization(session): def test_save_as_table(session, large_query_df): - set_bounds(300, 600) table_name = Utils.random_table_name() with session.query_history() as history: large_query_df.write.save_as_table(table_name, mode="overwrite") @@ -146,7 +147,6 @@ def test_save_as_table(session, large_query_df): def test_update_delete_merge(session, large_query_df): - set_bounds(300, 600) session._large_query_breakdown_enabled = True table_name = Utils.random_table_name() df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"]) @@ -186,7 +186,6 @@ def test_update_delete_merge(session, large_query_df): def test_copy_into_location(session, large_query_df): - set_bounds(300, 600) remote_file_path = f"{session.get_session_stage()}/df.parquet" with session.query_history() as history: large_query_df.write.copy_into_location( @@ -204,7 +203,6 @@ def test_copy_into_location(session, large_query_df): def test_pivot_unpivot(session): - set_bounds(300, 600) session.sql( """create or replace temp table monthly_sales(A int, B int, month text) as select * from values @@ -243,7 +241,6 @@ def test_pivot_unpivot(session): def test_sort(session): - set_bounds(300, 600) base_df = session.sql("select 1 as A, 2 as B") df1 = base_df.with_column("A", col("A") + lit(1)) df2 = base_df.with_column("B", col("B") + lit(1)) @@ -276,7 +273,6 @@ def test_sort(session): def test_multiple_query_plan(session, large_query_df): - set_bounds(300, 600) original_threshold = analyzer.ARRAY_BIND_THRESHOLD try: analyzer.ARRAY_BIND_THRESHOLD = 2 @@ -314,7 +310,6 @@ def test_multiple_query_plan(session, large_query_df): def test_optimization_skipped_with_transaction(session, large_query_df, caplog): """Test large query breakdown is skipped when transaction is enabled""" - set_bounds(300, 600) session.sql("begin").collect() assert Utils.is_active_transaction(session) with caplog.at_level(logging.DEBUG): @@ -330,7 +325,6 @@ def test_optimization_skipped_with_transaction(session, large_query_df, caplog): def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog): """Test large query breakdown is skipped plan is a view or dynamic table""" - set_bounds(300, 600) source_table = Utils.random_table_name() table_name = Utils.random_table_name() view_name = Utils.random_view_name() @@ -360,7 +354,6 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog): def test_async_job_with_large_query_breakdown(session, large_query_df): """Test large query breakdown gives same result for async and non-async jobs""" - set_bounds(300, 600) job = large_query_df.collect(block=False) result = job.result() assert result == large_query_df.collect() @@ -376,8 +369,6 @@ def test_async_job_with_large_query_breakdown(session, large_query_df): def test_add_parent_plan_uuid_to_statement_params(session, large_query_df): - set_bounds(300, 600) - with patch.object( session._conn, "run_query", wraps=session._conn.run_query ) as patched_run_query: @@ -400,7 +391,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): """Test complexity bounds affect number of partitions. Also test that when partitions are added, drop table queries are added. 
""" - set_bounds(300, 600) + set_bounds(session, 300, 600) assert len(large_query_df.queries["queries"]) == 2 assert len(large_query_df.queries["post_actions"]) == 1 assert large_query_df.queries["queries"][0].startswith( @@ -410,7 +401,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): "DROP TABLE If EXISTS" ) - set_bounds(300, 412) + set_bounds(session, 300, 412) assert len(large_query_df.queries["queries"]) == 3 assert len(large_query_df.queries["post_actions"]) == 2 assert large_query_df.queries["queries"][0].startswith( @@ -426,11 +417,11 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): "DROP TABLE If EXISTS" ) - set_bounds(0, 300) + set_bounds(session, 0, 300) assert len(large_query_df.queries["queries"]) == 1 assert len(large_query_df.queries["post_actions"]) == 0 - reset_bounds() + reset_bounds(session) assert len(large_query_df.queries["queries"]) == 1 assert len(large_query_df.queries["post_actions"]) == 0 diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 0e8bb0d902d..81b852c46c1 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -98,6 +98,24 @@ def test_range_statement(session: Session): ) +def test_literal_complexity_for_snowflake_values(session: Session): + from snowflake.snowpark._internal.analyzer import analyzer + + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + assert_df_subtree_query_complexity( + df1, {PlanNodeCategory.COLUMN: 4, PlanNodeCategory.LITERAL: 4} + ) + + try: + original_threshold = analyzer.ARRAY_BIND_THRESHOLD + analyzer.ARRAY_BIND_THRESHOLD = 2 + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + # SELECT "A", "B" from (SELECT * FROM TEMP_TABLE) + assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN: 3}) + finally: + analyzer.ARRAY_BIND_THRESHOLD = original_threshold + + def test_generator_table_function(session: Session): df1 = session.generator( seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150 diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index df0afc1099b..21e77883338 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -5,6 +5,7 @@ import os from functools import partial +from unittest.mock import patch import pytest @@ -719,6 +720,31 @@ def test_eliminate_numeric_sql_value_cast_optimization_enabled_on_session( new_session.eliminate_numeric_sql_value_cast_enabled = None +def test_large_query_breakdown_complexity_bounds(session): + original_bounds = session.large_query_breakdown_complexity_bounds + try: + with pytest.raises(ValueError, match="Expecting a tuple of two integers"): + session.large_query_breakdown_complexity_bounds = (1, 2, 3) + + with pytest.raises( + ValueError, match="Expecting a tuple of lower and upper bound" + ): + session.large_query_breakdown_complexity_bounds = (3, 2) + + with patch.object( + session._conn._telemetry_client, + "send_large_query_breakdown_update_complexity_bounds", + ) as patch_send: + session.large_query_breakdown_complexity_bounds = (1, 2) + assert session.large_query_breakdown_complexity_bounds == (1, 2) + assert patch_send.call_count == 1 + assert patch_send.call_args[0][0] == session.session_id + assert patch_send.call_args[0][1] == 1 + assert patch_send.call_args[0][2] == 2 + finally: + session.large_query_breakdown_complexity_bounds = original_bounds + + @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def 
test_create_session_from_default_config_file(monkeypatch, db_parameters): import tomlkit diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index bcfa2cfa512..39749de76f6 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -5,6 +5,7 @@ import decimal import sys +import uuid from functools import partial from typing import Any, Dict, Tuple @@ -599,6 +600,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -621,6 +623,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -643,6 +646,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -665,6 +669,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -687,6 +692,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -829,10 +835,15 @@ def test_dataframe_stat_functions_api_calls(session): column = 6 if session.sql_simplifier_enabled else 9 crosstab = df.stat.crosstab("empid", "month") + # uuid here is generated by an intermediate dataframe in crosstab implementation + # therefore we can't predict it. We check that the uuid for crosstab is same as + # that for df. 
+ uuid = df._plan.api_calls[0]["plan_uuid"] assert crosstab._plan.api_calls == [ { "name": "Session.create_dataframe[values]", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": uuid, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": {"group_by": 1, "column": column, "literal": 48}, @@ -851,6 +862,7 @@ def test_dataframe_stat_functions_api_calls(session): { "name": "Session.create_dataframe[values]", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": uuid, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": {"group_by": 1, "column": column, "literal": 48}, @@ -1166,3 +1178,96 @@ def send_large_query_optimization_skipped_telemetry(): ) assert data == expected_data assert type_ == "snowpark_large_query_breakdown_optimization_skipped" + + +def test_post_compilation_stage_telemetry(session): + client = session._conn._telemetry_client + uuid_str = str(uuid.uuid4()) + + def send_telemetry(): + summary_value = { + "cte_optimization_enabled": True, + "large_query_breakdown_enabled": True, + "complexity_score_bounds": (300, 600), + "time_taken_for_compilation": 0.136, + "time_taken_for_deep_copy_plan": 0.074, + "time_taken_for_cte_optimization": 0.01, + "time_taken_for_large_query_breakdown": 0.062, + "complexity_score_before_compilation": 1148, + "complexity_score_after_cte_optimization": [1148], + "complexity_score_after_large_query_breakdown": [514, 636], + } + client.send_query_compilation_summary_telemetry( + session_id=session.session_id, + plan_uuid=uuid_str, + compilation_stage_summary=summary_value, + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "plan_uuid": uuid_str, + "cte_optimization_enabled": True, + "large_query_breakdown_enabled": True, + "complexity_score_bounds": (300, 600), + "time_taken_for_compilation": 0.136, + "time_taken_for_deep_copy_plan": 0.074, + "time_taken_for_cte_optimization": 0.01, + "time_taken_for_large_query_breakdown": 0.062, + "complexity_score_before_compilation": 1148, + "complexity_score_after_cte_optimization": [1148], + "complexity_score_after_large_query_breakdown": [514, 636], + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_compilation_stage_statistics" + + +def test_temp_table_cleanup(session): + client = session._conn._telemetry_client + + def send_telemetry(): + client.send_temp_table_cleanup_telemetry( + session.session_id, + temp_table_cleaner_enabled=True, + num_temp_tables_cleaned=2, + num_temp_tables_created=5, + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "temp_table_cleaner_enabled": True, + "num_temp_tables_cleaned": 2, + "num_temp_tables_created": 5, + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_temp_table_cleanup" + + +def test_temp_table_cleanup_exception(session): + client = session._conn._telemetry_client + + def send_telemetry(): + client.send_temp_table_cleanup_abnormal_exception_telemetry( + session.session_id, + table_name="table_name_placeholder", + exception_message="exception_message_placeholder", + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "temp_table_cleanup_abnormal_exception_table_name": 
"table_name_placeholder", + "temp_table_cleanup_abnormal_exception_message": "exception_message_placeholder", + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_temp_table_cleanup_abnormal_exception" diff --git a/tests/integ/test_temp_table_cleanup.py b/tests/integ/test_temp_table_cleanup.py index 4ac87661484..cdd97d49937 100644 --- a/tests/integ/test_temp_table_cleanup.py +++ b/tests/integ/test_temp_table_cleanup.py @@ -12,6 +12,7 @@ from snowflake.snowpark._internal.utils import ( TempObjectType, random_name_for_temp_object, + warning_dict, ) from snowflake.snowpark.functions import col from tests.utils import IS_IN_STORED_PROC @@ -25,40 +26,61 @@ WAIT_TIME = 1 +@pytest.fixture(autouse=True) +def setup(session): + auto_clean_up_temp_table_enabled = session.auto_clean_up_temp_table_enabled + session.auto_clean_up_temp_table_enabled = True + yield + session.auto_clean_up_temp_table_enabled = auto_clean_up_temp_table_enabled + + def test_basic(session): + session._temp_table_auto_cleaner.ref_count_map.clear() df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() table_name = df1.table_name table_ids = table_name.split(".") df1.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = df1.select("*").filter(col("a") == 1) df2.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df3 = df1.union_all(df2) df3.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 del df2 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 del df3 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids) - assert table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 def test_function(session): + session._temp_table_auto_cleaner.ref_count_map.clear() table_name = None def f(session: Session) -> None: @@ -68,13 +90,16 @@ def f(session: Session) -> None: nonlocal table_name table_name = df.table_name assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() f(session) gc.collect() 
time.sleep(WAIT_TIME) assert not session._table_exists(table_name.split(".")) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 @pytest.mark.parametrize( @@ -86,33 +111,42 @@ def f(session: Session) -> None: ], ) def test_copy(session, copy_function): + session._temp_table_auto_cleaner.ref_count_map.clear() df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() table_name = df1.table_name table_ids = table_name.split(".") df1.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = copy_function(df1).select("*").filter(col("a") == 1) df2.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 2 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df2 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids) - assert table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_reference_count_map_multiple_sessions(db_parameters, session): + session._temp_table_auto_cleaner.ref_count_map.clear() new_session = Session.builder.configs(db_parameters).create() + new_session.auto_clean_up_temp_table_enabled = True try: df1 = session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] @@ -120,43 +154,59 @@ def test_reference_count_map_multiple_sessions(db_parameters, session): table_name1 = df1.table_name table_ids1 = table_name1.split(".") assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 1 - assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 + assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 0 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = new_session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] ).cache_result() table_name2 = df2.table_name table_ids2 = table_name2.split(".") - assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_created 
== 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids1) assert new_session._table_exists(table_ids2) assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 - assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 + assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - new_session._temp_table_auto_cleaner.start() del df2 gc.collect() time.sleep(WAIT_TIME) assert not new_session._table_exists(table_ids2) - assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 finally: new_session.close() def test_save_as_table_no_drop(session): - session._temp_table_auto_cleaner.start() + session._temp_table_auto_cleaner.ref_count_map.clear() def f(session: Session, temp_table_name: str) -> None: session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] ).write.save_as_table(temp_table_name, table_type="temp") - assert session._temp_table_auto_cleaner.ref_count_map[temp_table_name] == 0 + assert temp_table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) f(session, temp_table_name) @@ -165,34 +215,25 @@ def f(session: Session, temp_table_name: str) -> None: assert session._table_exists([temp_table_name]) -def test_start_stop(session): - session._temp_table_auto_cleaner.stop() - - df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() - table_name = df1.table_name +def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog): + warning_dict.clear() + with caplog.at_level(logging.WARNING): + session.auto_clean_up_temp_table_enabled = False + assert session.auto_clean_up_temp_table_enabled is False + assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() + table_name = df.table_name table_ids = table_name.split(".") - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 - del df1 + del df gc.collect() - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 - assert not session._temp_table_auto_cleaner.queue.empty() - assert session._table_exists(table_ids) - - session._temp_table_auto_cleaner.start() time.sleep(WAIT_TIME) - assert session._temp_table_auto_cleaner.queue.empty() - assert not session._table_exists(table_ids) - - -def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog): - with caplog.at_level(logging.WARNING): - 
session.auto_clean_up_temp_table_enabled = True + assert session._table_exists(table_ids) + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 + session.auto_clean_up_temp_table_enabled = True assert session.auto_clean_up_temp_table_enabled is True - assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text - assert session._temp_table_auto_cleaner.is_alive() - session.auto_clean_up_temp_table_enabled = False - assert session.auto_clean_up_temp_table_enabled is False - assert not session._temp_table_auto_cleaner.is_alive() + with pytest.raises( ValueError, match="value for auto_clean_up_temp_table_enabled must be True or False!", diff --git a/tests/mock/test_multithreading.py b/tests/mock/test_multithreading.py new file mode 100644 index 00000000000..5e0078212d6 --- /dev/null +++ b/tests/mock/test_multithreading.py @@ -0,0 +1,335 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import io +import json +import os +import tempfile +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import Thread + +import pytest + +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + LogicalPlan, + SaveMode, +) +from snowflake.snowpark._internal.utils import normalize_local_file +from snowflake.snowpark.functions import lit, when_matched +from snowflake.snowpark.mock._connection import MockServerConnection +from snowflake.snowpark.mock._functions import MockedFunctionRegistry +from snowflake.snowpark.mock._plan import MockExecutionPlan +from snowflake.snowpark.mock._snowflake_data_type import TableEmulator +from snowflake.snowpark.mock._stage_registry import StageEntityRegistry +from snowflake.snowpark.mock._telemetry import LocalTestOOBTelemetryService +from snowflake.snowpark.row import Row +from snowflake.snowpark.session import Session +from tests.utils import Utils + + +def test_table_update_merge_delete(session): + table_name = Utils.random_table_name() + num_threads = 10 + data = [[v, 11 * v] for v in range(10)] + df = session.create_dataframe(data, schema=["a", "b"]) + df.write.save_as_table(table_name, table_type="temp") + + source_df = df + t = session.table(table_name) + + def update_table(thread_id: int): + t.update({"b": 0}, t.a == lit(thread_id)) + + def merge_table(thread_id: int): + t.merge( + source_df, t.a == source_df.a, [when_matched().update({"b": source_df.b})] + ) + + def delete_table(thread_id: int): + t.delete(t.a == lit(thread_id)) + + # all threads will update column b to 0 where a = thread_id + with ThreadPoolExecutor(max_workers=num_threads) as executor: + # update + futures = [executor.submit(update_table, i) for i in range(num_threads)] + for future in as_completed(futures): + future.result() + + # all threads will set column b to 0 + Utils.check_answer(t.select(t.b), [Row(B=0) for _ in range(10)]) + + # merge + futures = [executor.submit(merge_table, i) for i in range(num_threads)] + for future in as_completed(futures): + future.result() + + # all threads will set column b to 11 * a + Utils.check_answer(t.select(t.b), [Row(B=11 * i) for i in range(10)]) + + # delete + futures = [executor.submit(delete_table, i) for i in range(num_threads)] + for future in as_completed(futures): + future.result() + + # all threads will delete their row + assert t.count() == 0 + + +def test_udf_register_and_invoke(session): + df = 
session.create_dataframe([[1], [2]], schema=["num"]) + num_threads = 10 + + def register_udf(x: int): + def echo(x: int) -> int: + return x + + return session.udf.register(echo, name="echo", replace=True) + + def invoke_udf(): + result = df.select(session.udf.call_udf("echo", df.num)).collect() + assert result[0][0] == 1 + assert result[1][0] == 2 + + threads = [] + for i in range(num_threads): + thread_register = Thread(target=register_udf, args=(i,)) + threads.append(thread_register) + thread_register.start() + + thread_invoke = Thread(target=invoke_udf) + threads.append(thread_invoke) + thread_invoke.start() + + for thread in threads: + thread.join() + + +def test_sp_register_and_invoke(session): + num_threads = 10 + + def increment_by_one_fn(session_: Session, x: int) -> int: + return x + 1 + + def register_sproc(): + session.sproc.register( + increment_by_one_fn, name="increment_by_one", replace=True + ) + + def invoke_sproc(): + result = session.call("increment_by_one", 1) + assert result == 2 + + threads = [] + for i in range(num_threads): + thread_register = Thread(target=register_sproc, args=(i,)) + threads.append(thread_register) + thread_register.start() + + thread_invoke = Thread(target=invoke_sproc) + threads.append(thread_invoke) + thread_invoke.start() + + for thread in threads: + thread.join() + + +def test_mocked_function_registry_created_once(): + num_threads = 10 + + result = [] + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(MockedFunctionRegistry.get_or_create) + for _ in range(num_threads) + ] + + for future in as_completed(futures): + result.append(future.result()) + + registry = MockedFunctionRegistry.get_or_create() + assert all([registry is r for r in result]) + + +@pytest.mark.parametrize("test_table", [True, False]) +def test_tabular_entity_registry(test_table): + conn = MockServerConnection() + entity_registry = conn.entity_registry + num_threads = 10 + + def write_read_and_drop_table(): + table_name = Utils.random_table_name() + table_emulator = TableEmulator() + + entity_registry.write_table(table_name, table_emulator, SaveMode.OVERWRITE) + + optional_table = entity_registry.read_table_if_exists(table_name) + if optional_table is not None: + assert optional_table.empty + + entity_registry.drop_table(table_name) + + def write_read_and_drop_view(): + view_name = Utils.random_view_name() + empty_logical_plan = LogicalPlan() + plan = MockExecutionPlan(empty_logical_plan, None) + + entity_registry.create_or_replace_view(plan, view_name) + + optional_view = entity_registry.read_view_if_exists(view_name) + if optional_view is not None: + assert optional_view.source_plan == empty_logical_plan + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + if test_table: + test_fn = write_read_and_drop_table + else: + test_fn = write_read_and_drop_view + futures = [executor.submit(test_fn) for _ in range(num_threads)] + + for future in as_completed(futures): + future.result() + + +def test_stage_entity_registry_put_and_get(): + stage_registry = StageEntityRegistry(MockServerConnection()) + num_threads = 10 + + def put_and_get_file(): + stage_registry.put( + normalize_local_file( + f"{os.path.dirname(os.path.abspath(__file__))}/files/test_file_1" + ), + "@test_stage/test_parent_dir/test_child_dir", + ) + with tempfile.TemporaryDirectory() as temp_dir: + stage_registry.get( + "@test_stage/test_parent_dir/test_child_dir/test_file_1", + temp_dir, + ) + assert os.path.isfile(os.path.join(temp_dir, "test_file_1")) + + 
threads = [] + for _ in range(num_threads): + thread = Thread(target=put_and_get_file) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + +def test_stage_entity_registry_upload_and_read(session): + stage_registry = StageEntityRegistry(MockServerConnection()) + num_threads = 10 + + def upload_and_read_json(thread_id: int): + json_string = json.dumps({"thread_id": thread_id}) + bytes_io = io.BytesIO(json_string.encode("utf-8")) + stage_registry.upload_stream( + input_stream=bytes_io, + stage_location="@test_stage/test_parent_dir", + file_name=f"test_file_{thread_id}", + ) + + df = stage_registry.read_file( + f"@test_stage/test_parent_dir/test_file_{thread_id}", + "json", + [], + session._analyzer, + {"INFER_SCHEMA": "True"}, + ) + + assert df['"thread_id"'].iloc[0] == thread_id + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(upload_and_read_json, i) for i in range(num_threads)] + + for future in as_completed(futures): + future.result() + + +def test_stage_entity_registry_create_or_replace(): + stage_registry = StageEntityRegistry(MockServerConnection()) + num_threads = 10 + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(stage_registry.create_or_replace_stage, f"test_stage_{i}") + for i in range(num_threads) + ] + + for future in as_completed(futures): + future.result() + + assert len(stage_registry._stage_registry) == num_threads + for i in range(num_threads): + assert f"test_stage_{i}" in stage_registry._stage_registry + + +def test_oob_telemetry_add(): + oob_service = LocalTestOOBTelemetryService.get_instance() + # clean up queue first + oob_service.export_queue_to_string() + num_threads = 10 + num_events_per_thread = 10 + + # create a function that adds 10 events to the queue + def add_events(thread_id: int): + for i in range(num_events_per_thread): + oob_service.add( + {f"thread_{thread_id}_event_{i}": f"dummy_event_{thread_id}_{i}"} + ) + + # set batch_size to 101 + is_enabled = oob_service.enabled + oob_service.enable() + original_batch_size = oob_service.batch_size + oob_service.batch_size = num_threads * num_events_per_thread + 1 + try: + # create 10 threads + threads = [] + for thread_id in range(num_threads): + thread = Thread(target=add_events, args=(thread_id,)) + threads.append(thread) + thread.start() + + # wait for all threads to finish + for thread in threads: + thread.join() + + # assert that the queue size is 100 + assert oob_service.queue.qsize() == num_threads * num_events_per_thread + finally: + oob_service.batch_size = original_batch_size + if not is_enabled: + oob_service.disable() + + +def test_oob_telemetry_flush(): + oob_service = LocalTestOOBTelemetryService.get_instance() + # clean up queue first + oob_service.export_queue_to_string() + + is_enabled = oob_service.enabled + oob_service.enable() + # add a dummy event + oob_service.add({"event": "dummy_event"}) + + try: + # flush the queue in multiple threads + num_threads = 10 + threads = [] + for _ in range(num_threads): + thread = Thread(target=oob_service.flush) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # assert that the queue is empty + assert oob_service.size() == 0 + finally: + if not is_enabled: + oob_service.disable() diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py index 7c5e3a40bb0..d94c80b8d67 100644 --- a/tests/unit/modin/modin/test_envvars.py +++ 
b/tests/unit/modin/modin/test_envvars.py @@ -166,6 +166,7 @@ def test_overrides(self): # Test for pandas doc when function is not defined on module. assert pandas.read_table.__doc__ in pd.read_table.__doc__ + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_not_redefining_classes_modin_issue_7138(self): original_dataframe_class = pd.DataFrame _init_doc_module() diff --git a/tests/unit/modin/test_aggregation_utils.py b/tests/unit/modin/test_aggregation_utils.py index 5434387ba71..6c9edfd024f 100644 --- a/tests/unit/modin/test_aggregation_utils.py +++ b/tests/unit/modin/test_aggregation_utils.py @@ -2,12 +2,20 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from types import MappingProxyType +from unittest import mock + import numpy as np import pytest +import snowflake.snowpark.modin.plugin._internal.aggregation_utils as aggregation_utils +from snowflake.snowpark.functions import greatest, sum as sum_ from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + SnowflakeAggFunc, + _is_supported_snowflake_agg_func, + _SnowparkPandasAggregation, check_is_aggregation_supported_in_snowflake, - is_supported_snowflake_agg_func, + get_snowflake_agg_func, ) @@ -53,8 +61,8 @@ ("quantile", {}, 1, False), ], ) -def test_is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None: - assert is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid +def test__is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None: + assert _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid @pytest.mark.parametrize( @@ -103,3 +111,40 @@ def test_check_aggregation_snowflake_execution_capability_by_args( agg_func=agg_func, agg_kwargs=agg_kwargs, axis=0 ) assert can_be_distributed == expected_result + + +@pytest.mark.parametrize( + "agg_func, agg_kwargs, axis, expected", + [ + (np.sum, {}, 0, SnowflakeAggFunc(sum_, True)), + ("max", {"skipna": False}, 1, SnowflakeAggFunc(greatest, True)), + ("test", {}, 0, None), + ], +) +def test_get_snowflake_agg_func(agg_func, agg_kwargs, axis, expected): + result = get_snowflake_agg_func(agg_func, agg_kwargs, axis) + if expected is None: + assert result is None + else: + assert result == expected + + +def test_get_snowflake_agg_func_with_no_implementation_on_axis_0(): + """Test get_snowflake_agg_func for a function that we support on axis=1 but not on axis=0.""" + # We have to patch the internal dictionary + # _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION here because there is + # no real function that we support on axis=1 but not on axis=0. 
+ with mock.patch.object( + aggregation_utils, + "_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION", + MappingProxyType( + { + "max": _SnowparkPandasAggregation( + preserves_snowpark_pandas_types=True, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=greatest, + ) + } + ), + ): + assert get_snowflake_agg_func(agg_func="max", agg_kwargs={}, axis=0) is None diff --git a/tests/unit/modin/test_series_dt.py b/tests/unit/modin/test_series_dt.py index be0039683a8..0b5572f0592 100644 --- a/tests/unit/modin/test_series_dt.py +++ b/tests/unit/modin/test_series_dt.py @@ -32,8 +32,6 @@ def mock_query_compiler_for_dt_series() -> SnowflakeQueryCompiler: [ (lambda s: s.dt.timetz, "timetz"), (lambda s: s.dt.to_period(), "to_period"), - (lambda s: s.dt.tz_localize(tz="UTC"), "tz_localize"), - (lambda s: s.dt.tz_convert(tz="UTC"), "tz_convert"), (lambda s: s.dt.strftime(date_format="YY/MM/DD"), "strftime"), (lambda s: s.dt.qyear, "qyear"), (lambda s: s.dt.start_time, "start_time"), diff --git a/tests/unit/test_expression_dependent_columns.py b/tests/unit/test_expression_dependent_columns.py index c31e5cc6290..c9b8a1ce38d 100644 --- a/tests/unit/test_expression_dependent_columns.py +++ b/tests/unit/test_expression_dependent_columns.py @@ -87,30 +87,37 @@ def test_expression(): a = Expression() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] b = Expression(child=UnresolvedAttribute("a")) assert b.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert b.dependent_column_names_with_duplication() == [] # root class Expression always returns empty dependency def test_literal(): a = Literal(5) assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] def test_attribute(): a = Attribute("A", IntegerType()) assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] def test_unresolved_attribute(): a = UnresolvedAttribute("A") assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] b = UnresolvedAttribute("a > 1", is_sql_text=True) assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] c = UnresolvedAttribute("$1 > 1", is_sql_text=True) assert c.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert c.dependent_column_names_with_duplication() == ["$"] def test_case_when(): @@ -118,46 +125,85 @@ def test_case_when(): b = Column("b") z = when(a > b, col("c")).when(a < b, col("d")).else_(col("e")) assert z._expression.dependent_column_names() == {'"A"', '"B"', '"C"', '"D"', '"E"'} + # verify column '"A"', '"B"' occurred twice in the dependency columns + assert z._expression.dependent_column_names_with_duplication() == [ + '"A"', + '"B"', + '"C"', + '"A"', + '"B"', + '"D"', + '"E"', + ] def test_collate(): a = Collate(UnresolvedAttribute("a"), "spec") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_function_expression(): a = FunctionExpression("test_func", [UnresolvedAttribute(x) for x in "abcd"], False) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # expressions with duplicated dependent column + b = FunctionExpression( + "test_func", [UnresolvedAttribute(x) for x in "abcdad"], False + ) + assert b.dependent_column_names() == set("abcd") + assert 
b.dependent_column_names_with_duplication() == list("abcdad") def test_in_expression(): a = InExpression(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") def test_like(): a = Like(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + # with duplication + b = Like(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] def test_list_agg(): a = ListAgg(UnresolvedAttribute("a"), ",", True) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_multiple_expression(): a = MultipleExpression([UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # with duplication + a = MultipleExpression([UnresolvedAttribute(x) for x in "abcdbea"]) + assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("abcdbea") def test_reg_exp(): a = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + b = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] def test_scalar_subquery(): a = ScalarSubquery(None) assert a.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert a.dependent_column_names_with_duplication() == list(COLUMN_DEPENDENCY_DOLLAR) def test_snowflake_udf(): @@ -165,21 +211,42 @@ def test_snowflake_udf(): "udf_name", [UnresolvedAttribute(x) for x in "abcd"], IntegerType() ) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # with duplication + b = SnowflakeUDF( + "udf_name", [UnresolvedAttribute(x) for x in "abcdfc"], IntegerType() + ) + assert b.dependent_column_names() == set("abcdf") + assert b.dependent_column_names_with_duplication() == list("abcdfc") def test_star(): a = Star([Attribute(x, IntegerType()) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + b = Star([]) + assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] def test_subfield_string(): a = SubfieldString(UnresolvedAttribute("a"), "field") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_within_group(): a = WithinGroup(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") + + b = WithinGroup( + UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcdea"] + ) + assert b.dependent_column_names() == set("abcde") + assert b.dependent_column_names_with_duplication() == list("eabcdea") @pytest.mark.parametrize( @@ -189,16 +256,19 @@ def test_within_group(): def test_unary_expression(expression_class): a = expression_class(child=UnresolvedAttribute("a")) assert a.dependent_column_names() == {"a"} + 
assert a.dependent_column_names_with_duplication() == ["a"] def test_alias(): a = Alias(child=Add(UnresolvedAttribute("a"), UnresolvedAttribute("b")), name="c") assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_cast(): a = Cast(UnresolvedAttribute("a"), IntegerType()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] @pytest.mark.parametrize( @@ -234,6 +304,19 @@ def test_binary_expression(expression_class): assert b.dependent_column_names() == {"B"} assert binary_expression.dependent_column_names() == {"A", "B"} + assert a.dependent_column_names_with_duplication() == ["A"] + assert b.dependent_column_names_with_duplication() == ["B"] + assert binary_expression.dependent_column_names_with_duplication() == ["A", "B"] + + # hierarchical expressions with duplication + hierarchical_binary_expression = expression_class(expression_class(a, b), b) + assert hierarchical_binary_expression.dependent_column_names() == {"A", "B"} + assert hierarchical_binary_expression.dependent_column_names_with_duplication() == [ + "A", + "B", + "B", + ] + @pytest.mark.parametrize( "expression_class", @@ -253,6 +336,18 @@ def test_grouping_set(expression_class): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] + + # with duplication + b = expression_class( + [ + UnresolvedAttribute("a"), + UnresolvedAttribute("a"), + UnresolvedAttribute("c"), + ] + ) + assert b.dependent_column_names() == {"a", "c"} + assert b.dependent_column_names_with_duplication() == ["a", "a", "c"] def test_grouping_sets_expression(): @@ -263,11 +358,13 @@ def test_grouping_sets_expression(): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] def test_sort_order(): a = SortOrder(UnresolvedAttribute("a"), Ascending()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_specified_window_frame(): @@ -275,12 +372,21 @@ def test_specified_window_frame(): RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("b") ) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + # with duplication + b = SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("a") + ) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] @pytest.mark.parametrize("expression_class", [RankRelatedFunctionExpression, Lag, Lead]) def test_rank_related_function_expression(expression_class): a = expression_class(UnresolvedAttribute("a"), 1, UnresolvedAttribute("b"), False) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_window_spec_definition(): @@ -295,6 +401,7 @@ def test_window_spec_definition(): ), ) assert a.dependent_column_names() == set("abcdef") + assert a.dependent_column_names_with_duplication() == list("abcdef") def test_window_expression(): @@ -310,6 +417,23 @@ def test_window_expression(): ) a = WindowExpression(UnresolvedAttribute("x"), window_spec_definition) assert a.dependent_column_names() == set("abcdefx") + assert a.dependent_column_names_with_duplication() == list("xabcdef") + + +def test_window_expression_with_duplication_columns(): + window_spec_definition = 
WindowSpecDefinition( + [UnresolvedAttribute("a"), UnresolvedAttribute("b")], + [ + SortOrder(UnresolvedAttribute("c"), Ascending()), + SortOrder(UnresolvedAttribute("a"), Ascending()), + ], + SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("e"), UnresolvedAttribute("f") + ), + ) + a = WindowExpression(UnresolvedAttribute("e"), window_spec_definition) + assert a.dependent_column_names() == set("abcef") + assert a.dependent_column_names_with_duplication() == list("eabcaef") @pytest.mark.parametrize( @@ -325,3 +449,4 @@ def test_window_expression(): def test_other_window_expressions(expression_class): a = expression_class() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 262c9e82c44..370ee455d62 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -112,6 +112,7 @@ def test_used_scoped_temp_object(): def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() + fake_connection._telemetry_client = mock.Mock() fake_connection.is_closed = MagicMock(return_value=False) exception_msg = "Mock exception for session.cancel_all" fake_connection.run_query = MagicMock(side_effect=Exception(exception_msg))
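
For context on the duplication-aware dependency API these expression tests exercise: `dependent_column_names()` returns a de-duplicated set, while `dependent_column_names_with_duplication()` returns column names in left-to-right traversal order and keeps repeats. Below is a minimal sketch consistent with the assertions above; the import paths are assumed from the test module's own imports, which this hunk does not show.

```python
# Sketch only: mirrors the behavior asserted in
# tests/unit/test_expression_dependent_columns.py.
from snowflake.snowpark._internal.analyzer.binary_expression import Add
from snowflake.snowpark._internal.analyzer.expression import UnresolvedAttribute

a = UnresolvedAttribute("a")
b = UnresolvedAttribute("b")
expr = Add(a, Add(a, b))  # "a" appears twice in the expression tree

# Set-based API: duplicates collapse.
assert expr.dependent_column_names() == {"a", "b"}
# List-based API: traversal order preserved, duplicates kept.
assert expr.dependent_column_names_with_duplication() == ["a", "a", "b"]
```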