diff --git a/CHANGELOG.md b/CHANGELOG.md index e0589d4a358..90f2003251d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,14 @@ #### New Features - Added support for `TimedeltaIndex.mean` method. +- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. +- Added support for `by`, `left_by`, `right_by`, `left_index`, and `right_index` for `pd.merge_asof`. +#### Bug Fixes + +- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`. +- Suppressed an unhelpful `SettingWithCopyWarning` that sometimes appeared when printing `Timedelta` columns. +- Fixed `inplace` argument for `Series` objects derived from other `Series` objects. ## 1.22.1 (2024-09-11) This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. @@ -124,6 +131,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det - Added support for `Series.dt.total_seconds` method. - Added support for `DataFrame.apply(axis=0)`. - Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`. +- Added support for `DatetimeIndex.tz_convert` and `DatetimeIndex.tz_localize`. #### Improvements diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst index 68b1935da96..3afe671aee7 100644 --- a/docs/source/modin/supported/datetime_index_supported.rst +++ b/docs/source/modin/supported/datetime_index_supported.rst @@ -82,9 +82,9 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``snap`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``tz_convert`` | N | | | +| ``tz_convert`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``tz_localize`` | N | | | +| ``tz_localize`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``round`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/general_supported.rst b/docs/source/modin/supported/general_supported.rst index 797ef3bbd59..b3a71f023a9 100644 --- a/docs/source/modin/supported/general_supported.rst +++ b/docs/source/modin/supported/general_supported.rst @@ -38,9 +38,7 @@ Data manipulations +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``merge_asof`` | P | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``. 
| -| | | , ``left_index``, ``right_index``| | -| | | , ``suffixes``, ``tolerance`` | | +| ``merge_asof`` | P | ``suffixes``, ``tolerance`` | ``N`` if param ``direction`` is ``nearest`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge_ordered`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 76e91b7da92..d8622299ea9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -956,10 +956,7 @@ def do_resolve_with_resolved_children( schema_query = schema_query_for_values_statement(logical_plan.output) if logical_plan.data: - if ( - len(logical_plan.output) * len(logical_plan.data) - < ARRAY_BIND_THRESHOLD - ): + if not logical_plan.is_large_local_data: return self.plan_builder.query( values_statement(logical_plan.output, logical_plan.data), logical_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 3ed969caada..22591f55e47 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional +from typing import AbstractSet, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -29,6 +30,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.left, self.right) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index a2d21db4eb2..a7cb5fd97a9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -35,6 +35,13 @@ def derive_dependent_columns( *expressions: "Optional[Expression]", ) -> Optional[AbstractSet[str]]: + """ + Given set of expressions, derive the set of columns that the expressions dependents on. + + Note, the returned dependent columns is a set without duplication. For example, given expression + concat(col1, upper(co1), upper(col2)), the result will be {col1, col2} even if col1 has + occurred in the given expression twice. + """ result = set() for exp in expressions: if exp is not None: @@ -48,6 +55,23 @@ def derive_dependent_columns( return result +def derive_dependent_columns_with_duplication( + *expressions: "Optional[Expression]", +) -> List[str]: + """ + Given set of expressions, derive the list of columns that the expression dependents on. + + Note, the returned columns will have duplication if the column occurred more than once in + the given expression. 
For example, concat(col1, upper(co1), upper(col2)) will have result + [col1, col1, col2], where col1 occurred twice in the result. + """ + result = [] + for exp in expressions: + if exp is not None: + result.extend(exp.dependent_column_names_with_duplication()) + return result + + class Expression: """Consider removing attributes, and adding properties and methods. A subclass of Expression may have no child, one child, or multiple children. @@ -68,6 +92,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. return COLUMN_DEPENDENCY_EMPTY + def dependent_column_names_with_duplication(self) -> List[str]: + return [] + @property def pretty_name(self) -> str: """Returns a user-facing string representation of this expression's name. @@ -143,6 +170,9 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR + def dependent_column_names_with_duplication(self) -> List[str]: + return list(COLUMN_DEPENDENCY_DOLLAR) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.plan.cumulative_node_complexity @@ -156,6 +186,9 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.expressions) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( @@ -172,6 +205,9 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.columns, *self.values) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN @@ -212,6 +248,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} + def dependent_column_names_with_duplication(self) -> List[str]: + return [self.name] + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -235,6 +274,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: else COLUMN_DEPENDENCY_ALL ) + def dependent_column_names_with_duplication(self) -> List[str]: + return ( + derive_dependent_columns_with_duplication(*self.expressions) + if self.expressions + else [] # we currently do not handle * dependency + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1} @@ -278,6 +324,14 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names + def dependent_column_names_with_duplication(self) -> List[str]: + return ( + [] + if (self._dependent_column_names == COLUMN_DEPENDENCY_ALL) + or (self._dependent_column_names is None) + else list(self._dependent_column_names) + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -371,6 +425,9 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, 
self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern @@ -400,6 +457,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern @@ -423,6 +483,9 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec @@ -444,6 +507,9 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -466,6 +532,9 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -510,6 +579,9 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -525,6 +597,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols) + @property def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) @@ -549,13 +624,21 @@ def __init__( self.branches = branches self.else_value = else_value - def dependent_column_names(self) -> Optional[AbstractSet[str]]: + @property + def _child_expressions(self) -> List[Expression]: exps = [] for exp_tuple in self.branches: exps.extend(exp_tuple) if self.else_value is not None: exps.append(self.else_value) - return derive_dependent_columns(*exps) + + return exps + + def dependent_column_names(self) -> Optional[AbstractSet[str]]: + return derive_dependent_columns(*self._child_expressions) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self._child_expressions) @property def plan_node_category(self) -> PlanNodeCategory: @@ -602,6 +685,9 @@ def __init__( def 
dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -617,6 +703,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.col) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -636,6 +725,9 @@ def __init__(self, exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.exprs) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 84cd63fd87d..012940471d0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.group_by_exprs) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) + def dependent_column_names_with_duplication(self) -> List[str]: + flattened_args = [exp for sublist in self.args for exp in sublist] + return derive_dependent_columns_with_duplication(*flattened_args) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index e3e032cd94b..aa8730dcf7f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -144,10 +144,27 @@ def __init__( self.data = data self.schema_query = schema_query + @property + def is_large_local_data(self) -> bool: + from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD + + return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self.is_large_local_data: + # When the number of literals exceeds the threshold, we generate 3 queries: + # 1. create table query + # 2. 
insert into table query + # 3. select * from table query + # We only consider the complexity from the final select * query since other queries + # are built based on it. + return { + PlanNodeCategory.COLUMN: 1, + } + + # If we stay under the threshold, we generate a single query: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) - # TODO: use ARRAY_BIND_THRESHOLD return { PlanNodeCategory.COLUMN: len(self.output), PlanNodeCategory.LITERAL: len(self.data) * len(self.output), diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 1d06f7290a0..82451245e4c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional, Type +from typing import AbstractSet, List, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) @@ -55,3 +56,6 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e5886e11069..1ae08e8fde2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,12 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from typing import AbstractSet, Dict, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -36,6 +37,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 69db3f265ce..4381c4a2e22 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -71,6 +72,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.lower, self.upper) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + *self.partition_spec, *self.order_spec, self.frame_spec + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # partition_spec order_by_spec frame_spec @@ -138,6 +147,11 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + self.window_function, self.window_spec + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW @@ -171,6 +185,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.default) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # for func_name diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index 836628345aa..8d16383a4ce 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -58,11 +58,6 @@ ) from snowflake.snowpark.session import Session -# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT -# in Snowflake. This is the limit where we start seeing compilation errors. 
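# A brief usage sketch, assuming an active Snowpark Session: the hard-coded bounds deleted
# just below are replaced by a per-session value that LargeQueryBreakdown.__init__ now reads
# as an indexable (lower, upper) pair. The connection_parameters dict is a placeholder.
from snowflake.snowpark import Session

connection_parameters = {"account": "<account>", "user": "<user>", "password": "<password>"}
session = Session.builder.configs(connection_parameters).create()
lower, upper = session.large_query_breakdown_complexity_bounds  # previously fixed at 10_000_000 and 12_000_000
# If the attribute is also settable (an assumption here, suggested by the new
# send_large_query_breakdown_update_complexity_bounds telemetry message), the bounds could be tuned:
# session.large_query_breakdown_complexity_bounds = (9_000_000, 11_000_000)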
-COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000 -COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000 - _logger = logging.getLogger(__name__) @@ -123,6 +118,12 @@ def __init__( self._query_generator = query_generator self.logical_plans = logical_plans self._parent_map = defaultdict(set) + self.complexity_score_lower_bound = ( + session.large_query_breakdown_complexity_bounds[0] + ) + self.complexity_score_upper_bound = ( + session.large_query_breakdown_complexity_bounds[1] + ) def apply(self) -> List[LogicalPlan]: if is_active_transaction(self.session): @@ -183,13 +184,13 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]: complexity_score = get_complexity_score(root.cumulative_node_complexity) _logger.debug(f"Complexity score for root {type(root)} is: {complexity_score}") - if complexity_score <= COMPLEXITY_SCORE_UPPER_BOUND: + if complexity_score <= self.complexity_score_upper_bound: # Skip optimization if the complexity score is within the upper bound. return [root] plans = [] # TODO: SNOW-1617634 Have a one pass algorithm to find the valid node for partitioning - while complexity_score > COMPLEXITY_SCORE_UPPER_BOUND: + while complexity_score > self.complexity_score_upper_bound: child = self._find_node_to_breakdown(root) if child is None: _logger.debug( @@ -277,7 +278,9 @@ def _is_node_valid_to_breakdown(self, node: LogicalPlan) -> Tuple[bool, int]: """ score = get_complexity_score(node.cumulative_node_complexity) valid_node = ( - COMPLEXITY_SCORE_LOWER_BOUND < score < COMPLEXITY_SCORE_UPPER_BOUND + self.complexity_score_lower_bound + < score + < self.complexity_score_upper_bound ) and self._is_node_pipeline_breaker(node) if valid_node: _logger.debug( diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 211b66820ec..3e6dba71be4 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -16,8 +16,6 @@ ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.compiler.large_query_breakdown import ( - COMPLEXITY_SCORE_LOWER_BOUND, - COMPLEXITY_SCORE_UPPER_BOUND, LargeQueryBreakdown, ) from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import ( @@ -128,10 +126,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: summary_value = { TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, - CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: ( - COMPLEXITY_SCORE_LOWER_BOUND, - COMPLEXITY_SCORE_UPPER_BOUND, - ), + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index 223b6a1326f..be61a1ac924 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -11,6 +11,9 @@ class CompilationStageTelemetryField(Enum): "snowpark_large_query_breakdown_optimization_skipped" ) 
TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics" + TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS = ( + "snowpark_large_query_breakdown_update_complexity_bounds" + ) # keys KEY_REASON = "reason" diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 8b9ef2acccb..025eb57c540 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -79,6 +79,20 @@ class TelemetryField(Enum): QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" QUERY_PLAN_COMPLEXITY = "query_plan_complexity" + # temp table cleanup + TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup" + NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned" + NUM_TEMP_TABLES_CREATED = "num_temp_tables_created" + TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled" + TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = ( + "snowpark_temp_table_cleanup_abnormal_exception" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = ( + "temp_table_cleanup_abnormal_exception_table_name" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = ( + "temp_table_cleanup_abnormal_exception_message" + ) # These DataFrame APIs call other DataFrame APIs @@ -374,7 +388,7 @@ def send_sql_simplifier_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, - TelemetryField.SQL_SIMPLIFIER_ENABLED.value: True, + TelemetryField.SQL_SIMPLIFIER_ENABLED.value: sql_simplifier_enabled, }, } self.send(message) @@ -428,7 +442,7 @@ def send_large_query_breakdown_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: True, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: value, }, } self.send(message) @@ -464,3 +478,60 @@ def send_large_query_optimization_skipped_telemetry( }, } self.send(message) + + def send_temp_table_cleanup_telemetry( + self, + session_id: str, + temp_table_cleaner_enabled: bool, + num_temp_tables_cleaned: int, + num_temp_tables_created: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANER_ENABLED.value: temp_table_cleaner_enabled, + TelemetryField.NUM_TEMP_TABLES_CLEANED.value: num_temp_tables_cleaned, + TelemetryField.NUM_TEMP_TABLES_CREATED.value: num_temp_tables_created, + }, + } + self.send(message) + + def send_temp_table_cleanup_abnormal_exception_telemetry( + self, + session_id: str, + table_name: str, + exception_message: str, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME.value: table_name, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE.value: exception_message, + }, + } + self.send(message) + + def send_large_query_breakdown_update_complexity_bounds( + self, session_id: int, lower_bound: int, upper_bound: int + ): + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_DATA.value: { + 
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: ( + lower_bound, + upper_bound, + ), + }, + }, + } + self.send(message) diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py index b9055c6fc58..4fa17498d34 100644 --- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py +++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py @@ -4,9 +4,7 @@ import logging import weakref from collections import defaultdict -from queue import Empty, Queue -from threading import Event, Thread -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable @@ -33,12 +31,6 @@ def __init__(self, session: "Session") -> None: # to its reference count for later temp table management # this dict will still be maintained even if the cleaner is stopped (`stop()` is called) self.ref_count_map: Dict[str, int] = defaultdict(int) - # unused temp table will be put into the queue for cleanup - self.queue: Queue = Queue() - # thread for removing temp tables (running DROP TABLE sql) - self.cleanup_thread: Optional[Thread] = None - # An event managing a flag that indicates whether the cleaner is started - self.stop_event = Event() def add(self, table: SnowflakeTable) -> None: self.ref_count_map[table.name] += 1 @@ -46,61 +38,60 @@ def add(self, table: SnowflakeTable) -> None: # and this table will be dropped finally _ = weakref.finalize(table, self._delete_ref_count, table.name) - def _delete_ref_count(self, name: str) -> None: + def _delete_ref_count(self, name: str) -> None: # pragma: no cover """ Decrements the reference count of a temporary table, and if the count reaches zero, puts this table in the queue for cleanup. 
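+        The drop is only issued when auto_clean_up_temp_table_enabled is set on the session.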
""" self.ref_count_map[name] -= 1 if self.ref_count_map[name] == 0: - self.ref_count_map.pop(name) - # clean up - self.queue.put(name) + if self.session.auto_clean_up_temp_table_enabled: + self.drop_table(name) elif self.ref_count_map[name] < 0: logging.debug( f"Unexpected reference count {self.ref_count_map[name]} for table {name}" ) - def process_cleanup(self) -> None: - while not self.stop_event.is_set(): - try: - # it's non-blocking after timeout and become interruptable with stop_event - # it will raise an `Empty` exception if queue is empty after timeout, - # then we catch this exception and avoid breaking loop - table_name = self.queue.get(timeout=1) - self.drop_table(table_name) - except Empty: - continue - - def drop_table(self, name: str) -> None: + def drop_table(self, name: str) -> None: # pragma: no cover common_log_text = f"temp table {name} in session {self.session.session_id}" - logging.debug(f"Cleanup Thread: Ready to drop {common_log_text}") + logging.debug(f"Ready to drop {common_log_text}") + query_id = None try: - # TODO SNOW-1556553: Remove this workaround once multi-threading of Snowpark session is supported - with self.session._conn._conn.cursor() as cursor: - cursor.execute( - f"drop table if exists {name} /* internal query to drop unused temp table */", - _statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name}, + async_job = self.session.sql( + f"drop table if exists {name} /* internal query to drop unused temp table */", + )._internal_collect_with_tag_no_telemetry( + block=False, statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name} + ) + query_id = async_job.query_id + logging.debug(f"Dropping {common_log_text} with query id {query_id}") + except Exception as ex: # pragma: no cover + warning_message = f"Failed to drop {common_log_text}, exception: {ex}" + logging.warning(warning_message) + if query_id is None: + # If no query_id is available, it means the query haven't been accepted by gs, + # and it won't occur in our job_etl_view, send a separate telemetry for recording. + self.session._conn._telemetry_client.send_temp_table_cleanup_abnormal_exception_telemetry( + self.session.session_id, + name, + str(ex), ) - logging.debug(f"Cleanup Thread: Successfully dropped {common_log_text}") - except Exception as ex: - logging.warning( - f"Cleanup Thread: Failed to drop {common_log_text}, exception: {ex}" - ) # pragma: no cover - - def is_alive(self) -> bool: - return self.cleanup_thread is not None and self.cleanup_thread.is_alive() - - def start(self) -> None: - self.stop_event.clear() - if not self.is_alive(): - self.cleanup_thread = Thread(target=self.process_cleanup) - self.cleanup_thread.start() def stop(self) -> None: """ - The cleaner will stop immediately and leave unfinished temp tables in the queue. + Stops the cleaner (no-op) and sends the telemetry. 
""" - self.stop_event.set() - if self.is_alive(): - self.cleanup_thread.join() + self.session._conn._telemetry_client.send_temp_table_cleanup_telemetry( + self.session.session_id, + temp_table_cleaner_enabled=self.session.auto_clean_up_temp_table_enabled, + num_temp_tables_cleaned=self.num_temp_tables_cleaned, + num_temp_tables_created=self.num_temp_tables_created, + ) + + @property + def num_temp_tables_created(self) -> int: + return len(self.ref_count_map) + + @property + def num_temp_tables_cleaned(self) -> int: + # TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled + return sum(v == 0 for v in self.ref_count_map.values()) diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index 6960d0eb629..8f9834630b7 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -88,13 +88,13 @@ # TODO: SNOW-851745 make sure add all Snowpark pandas API general functions from modin.pandas import plotting # type: ignore[import] +from modin.pandas.dataframe import DataFrame from modin.pandas.series import Series from snowflake.snowpark.modin.pandas.api.extensions import ( register_dataframe_accessor, register_series_accessor, ) -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.general import ( bdate_range, concat, @@ -185,10 +185,8 @@ modin.pandas.base._ATTRS_NO_LOOKUP.update(_ATTRS_NO_LOOKUP) -# For any method defined on Series/DF, add telemetry to it if it: -# 1. Is defined directly on an upstream class -# 2. The method name does not start with an _, or is in TELEMETRY_PRIVATE_METHODS - +# For any method defined on Series/DF, add telemetry to it if the method name does not start with an +# _, or the method is in TELEMETRY_PRIVATE_METHODS. This includes methods defined as an extension/override. for attr_name in dir(Series): # Since Series is defined in upstream Modin, all of its members were either defined upstream # or overridden by extension. @@ -197,11 +195,9 @@ try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name)) ) - -# TODO: SNOW-1063346 -# Since we still use the vendored version of DataFrame and the overrides for the top-level -# namespace haven't been performed yet, we need to set properties on the vendored version for attr_name in dir(DataFrame): + # Since DataFrame is defined in upstream Modin, all of its members were either defined upstream + # or overridden by extension. if not attr_name.startswith("_") or attr_name in TELEMETRY_PRIVATE_METHODS: register_dataframe_accessor(attr_name)( try_add_telemetry_to_attribute(attr_name, getattr(DataFrame, attr_name)) diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py index 6a34f50e42a..47d44835fe4 100644 --- a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py @@ -19,9 +19,12 @@ # existing code originally distributed by the Modin project, under the Apache License, # Version 2.0. 
-from modin.pandas.api.extensions import register_series_accessor +from modin.pandas.api.extensions import ( + register_dataframe_accessor, + register_series_accessor, +) -from .extensions import register_dataframe_accessor, register_pd_accessor +from .extensions import register_pd_accessor __all__ = [ "register_dataframe_accessor", diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py index 45896292e74..05424c92072 100644 --- a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py @@ -86,49 +86,6 @@ def decorator(new_attr: Any): return decorator -def register_dataframe_accessor(name: str): - """ - Registers a dataframe attribute with the name provided. - This is a decorator that assigns a new attribute to DataFrame. It can be used - with the following syntax: - ``` - @register_dataframe_accessor("new_method") - def my_new_dataframe_method(*args, **kwargs): - # logic goes here - return - ``` - The new attribute can then be accessed with the name provided: - ``` - df.new_method(*my_args, **my_kwargs) - ``` - - If you want a property accessor, you must annotate with @property - after the call to this function: - ``` - @register_dataframe_accessor("new_prop") - @property - def my_new_dataframe_property(*args, **kwargs): - return _prop - ``` - - Parameters - ---------- - name : str - The name of the attribute to assign to DataFrame. - Returns - ------- - decorator - Returns the decorator function. - """ - import snowflake.snowpark.modin.pandas as pd - - return _set_attribute_on_obj( - name, - pd.dataframe._DATAFRAME_EXTENSIONS_, - pd.dataframe.DataFrame, - ) - - def register_pd_accessor(name: str): """ Registers a pd namespace attribute with the name provided. diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py deleted file mode 100644 index 83893e83e9c..00000000000 --- a/src/snowflake/snowpark/modin/pandas/dataframe.py +++ /dev/null @@ -1,3511 +0,0 @@ -# -# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. -# - -# Licensed to Modin Development Team under one or more contributor license agreements. -# See the NOTICE file distributed with this work for additional information regarding -# copyright ownership. The Modin Development Team licenses this file to you under the -# Apache License, Version 2.0 (the "License"); you may not use this file except in -# compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under -# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific language -# governing permissions and limitations under the License. - -# Code in this file may constitute partial or total reimplementation, or modification of -# existing code originally distributed by the Modin project, under the Apache License, -# Version 2.0. 
- -"""Module houses ``DataFrame`` class, that is distributed version of ``pandas.DataFrame``.""" - -from __future__ import annotations - -import collections -import datetime -import functools -import itertools -import re -import sys -import warnings -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from logging import getLogger -from typing import IO, Any, Callable, Literal - -import numpy as np -import pandas -from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor -from modin.pandas.base import BasePandasDataset - -# from . import _update_engine -from modin.pandas.iterator import PartitionIterator -from modin.pandas.series import Series -from pandas._libs.lib import NoDefault, no_default -from pandas._typing import ( - AggFuncType, - AnyArrayLike, - Axes, - Axis, - CompressionOptions, - FilePath, - FillnaOptions, - IgnoreRaise, - IndexLabel, - Level, - PythonFuncType, - Renamer, - Scalar, - StorageOptions, - Suffixes, - WriteBuffer, -) -from pandas.core.common import apply_if_callable, is_bool_indexer -from pandas.core.dtypes.common import ( - infer_dtype_from_object, - is_bool_dtype, - is_dict_like, - is_list_like, - is_numeric_dtype, -) -from pandas.core.dtypes.inference import is_hashable, is_integer -from pandas.core.indexes.frozen import FrozenList -from pandas.io.formats.printing import pprint_thing -from pandas.util._validators import validate_bool_kwarg - -from snowflake.snowpark.modin import pandas as pd -from snowflake.snowpark.modin.pandas.groupby import ( - DataFrameGroupBy, - validate_groupby_args, -) -from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( - SnowparkPandasRowPartitionIterator, -) -from snowflake.snowpark.modin.pandas.utils import ( - create_empty_native_pandas_frame, - from_non_pandas, - from_pandas, - is_scalar, - raise_if_native_pandas_objects, - replace_external_data_keys_with_empty_pandas_series, - replace_external_data_keys_with_query_compiler, -) -from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated -from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike -from snowflake.snowpark.modin.plugin.utils.error_message import ( - ErrorMessage, - dataframe_not_implemented, -) -from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP -from snowflake.snowpark.modin.plugin.utils.warning_message import ( - SET_DATAFRAME_ATTRIBUTE_WARNING, - WarningMessage, -) -from snowflake.snowpark.modin.utils import _inherit_docstrings, hashable, to_pandas -from snowflake.snowpark.udf import UserDefinedFunction - -logger = getLogger(__name__) - -DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = ( - "Currently do not support Series or list-like keys with range-like values" -) - -DF_SETITEM_SLICE_AS_SCALAR_VALUE = ( - "Currently do not support assigning a slice value as if it's a scalar value" -) - -DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = ( - "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark " - "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which " - "can work on the entire DataFrame in one shot." 
-) - -# Dictionary of extensions assigned to this class -_DATAFRAME_EXTENSIONS_ = {} - - -@_inherit_docstrings( - pandas.DataFrame, - excluded=[ - pandas.DataFrame.flags, - pandas.DataFrame.cov, - pandas.DataFrame.merge, - pandas.DataFrame.reindex, - pandas.DataFrame.to_parquet, - pandas.DataFrame.fillna, - ], - apilink="pandas.DataFrame", -) -class DataFrame(BasePandasDataset): - _pandas_class = pandas.DataFrame - - def __init__( - self, - data=None, - index=None, - columns=None, - dtype=None, - copy=None, - query_compiler=None, - ) -> None: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # Siblings are other dataframes that share the same query compiler. We - # use this list to update inplace when there is a shallow copy. - from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native - - self._siblings = [] - - # Engine.subscribe(_update_engine) - if isinstance(data, (DataFrame, Series)): - self._query_compiler = data._query_compiler.copy() - if index is not None and any(i not in data.index for i in index): - ErrorMessage.not_implemented( - "Passing non-existant columns or index values to constructor not" - + " yet implemented." - ) # pragma: no cover - if isinstance(data, Series): - # We set the column name if it is not in the provided Series - if data.name is None: - self.columns = [0] if columns is None else columns - # If the columns provided are not in the named Series, pandas clears - # the DataFrame and sets columns to the columns provided. - elif columns is not None and data.name not in columns: - self._query_compiler = from_pandas( - self.__constructor__(columns=columns) - )._query_compiler - if index is not None: - self._query_compiler = data.loc[index]._query_compiler - elif columns is None and index is None: - data._add_sibling(self) - else: - if columns is not None and any(i not in data.columns for i in columns): - ErrorMessage.not_implemented( - "Passing non-existant columns or index values to constructor not" - + " yet implemented." 
- ) # pragma: no cover - if index is None: - index = slice(None) - if columns is None: - columns = slice(None) - self._query_compiler = data.loc[index, columns]._query_compiler - - # Check type of data and use appropriate constructor - elif query_compiler is None: - distributed_frame = from_non_pandas(data, index, columns, dtype) - if distributed_frame is not None: - self._query_compiler = distributed_frame._query_compiler - return - - if isinstance(data, pandas.Index): - pass - elif is_list_like(data) and not is_dict_like(data): - old_dtype = getattr(data, "dtype", None) - values = [ - obj._to_pandas() if isinstance(obj, Series) else obj for obj in data - ] - if isinstance(data, np.ndarray): - data = np.array(values, dtype=old_dtype) - else: - try: - data = type(data)(values, dtype=old_dtype) - except TypeError: - data = values - elif is_dict_like(data) and not isinstance( - data, (pandas.Series, Series, pandas.DataFrame, DataFrame) - ): - if columns is not None: - data = {key: value for key, value in data.items() if key in columns} - - if len(data) and all(isinstance(v, Series) for v in data.values()): - from .general import concat - - new_qc = concat( - data.values(), axis=1, keys=data.keys() - )._query_compiler - - if dtype is not None: - new_qc = new_qc.astype({col: dtype for col in new_qc.columns}) - if index is not None: - new_qc = new_qc.reindex( - axis=0, labels=try_convert_index_to_native(index) - ) - if columns is not None: - new_qc = new_qc.reindex( - axis=1, labels=try_convert_index_to_native(columns) - ) - - self._query_compiler = new_qc - return - - data = { - k: v._to_pandas() if isinstance(v, Series) else v - for k, v in data.items() - } - pandas_df = pandas.DataFrame( - data=try_convert_index_to_native(data), - index=try_convert_index_to_native(index), - columns=try_convert_index_to_native(columns), - dtype=dtype, - copy=copy, - ) - self._query_compiler = from_pandas(pandas_df)._query_compiler - else: - self._query_compiler = query_compiler - - def __repr__(self): - """ - Return a string representation for a particular ``DataFrame``. - - Returns - ------- - str - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - num_rows = pandas.get_option("display.max_rows") or len(self) - # see _repr_html_ for comment, allow here also all column behavior - num_cols = pandas.get_option("display.max_columns") or len(self.columns) - - ( - row_count, - col_count, - repr_df, - ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x") - result = repr(repr_df) - - # if truncated, add shape information - if is_repr_truncated(row_count, col_count, num_rows, num_cols): - # The split here is so that we don't repr pandas row lengths. - return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format( - row_count, col_count - ) - else: - return result - - def _repr_html_(self): # pragma: no cover - """ - Return a html representation for a particular ``DataFrame``. - - Returns - ------- - str - - Notes - ----- - Supports pandas `display.max_rows` and `display.max_columns` options. - """ - num_rows = pandas.get_option("display.max_rows") or 60 - # Modin uses here 20 as default, but this does not coincide well with pandas option. Therefore allow - # here value=0 which means display all columns. 
- num_cols = pandas.get_option("display.max_columns") - - ( - row_count, - col_count, - repr_df, - ) = self._query_compiler.build_repr_df(num_rows, num_cols) - result = repr_df._repr_html_() - - if is_repr_truncated(row_count, col_count, num_rows, num_cols): - # We split so that we insert our correct dataframe dimensions. - return ( - result.split("

")[0] - + f"

{row_count} rows × {col_count} columns

\n" - ) - else: - return result - - def _get_columns(self) -> pandas.Index: - """ - Get the columns for this Snowpark pandas ``DataFrame``. - - Returns - ------- - Index - The all columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.columns - - def _set_columns(self, new_columns: Axes) -> None: - """ - Set the columns for this Snowpark pandas ``DataFrame``. - - Parameters - ---------- - new_columns : - The new columns to set. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._update_inplace( - new_query_compiler=self._query_compiler.set_columns(new_columns) - ) - - columns = property(_get_columns, _set_columns) - - @property - def ndim(self) -> int: - return 2 - - def drop_duplicates( - self, subset=None, keep="first", inplace=False, ignore_index=False - ): # noqa: PR01, RT01, D200 - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - """ - Return ``DataFrame`` with duplicate rows removed. - """ - return super().drop_duplicates( - subset=subset, keep=keep, inplace=inplace, ignore_index=ignore_index - ) - - def dropna( - self, - *, - axis: Axis = 0, - how: str | NoDefault = no_default, - thresh: int | NoDefault = no_default, - subset: IndexLabel = None, - inplace: bool = False, - ): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super()._dropna( - axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace - ) - - @property - def dtypes(self): # noqa: RT01, D200 - """ - Return the dtypes in the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.dtypes - - def duplicated( - self, subset: Hashable | Sequence[Hashable] = None, keep: DropKeep = "first" - ): - """ - Return boolean ``Series`` denoting duplicate rows. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - df = self[subset] if subset is not None else self - new_qc = df._query_compiler.duplicated(keep=keep) - duplicates = self._reduce_dimension(new_qc) - # remove Series name which was assigned automatically by .apply in QC - # this is pandas behavior, i.e., if duplicated result is a series, no name is returned - duplicates.name = None - return duplicates - - @property - def empty(self) -> bool: - """ - Indicate whether ``DataFrame`` is empty. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return len(self.columns) == 0 or len(self) == 0 - - @property - def axes(self): - """ - Return a list representing the axes of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return [self.index, self.columns] - - @property - def shape(self) -> tuple[int, int]: - """ - Return a tuple representing the dimensionality of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return len(self), len(self.columns) - - def add_prefix(self, prefix): - """ - Prefix labels with string `prefix`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas converts non-string prefix values into str and adds it to the column labels. - return self.__constructor__( - query_compiler=self._query_compiler.add_substring( - str(prefix), substring_type="prefix", axis=1 - ) - ) - - def add_suffix(self, suffix): - """ - Suffix labels with string `suffix`. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas converts non-string suffix values into str and appends it to the column labels. - return self.__constructor__( - query_compiler=self._query_compiler.add_substring( - str(suffix), substring_type="suffix", axis=1 - ) - ) - - @dataframe_not_implemented() - def map( - self, func, na_action: str | None = None, **kwargs - ) -> DataFrame: # pragma: no cover - if not callable(func): - raise ValueError(f"'{type(func)}' object is not callable") - return self.__constructor__( - query_compiler=self._query_compiler.map(func, na_action=na_action, **kwargs) - ) - - def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not callable(func): - raise TypeError(f"{func} is not callable") - return self.__constructor__( - query_compiler=self._query_compiler.applymap( - func, na_action=na_action, **kwargs - ) - ) - - def aggregate( - self, func: AggFuncType = None, axis: Axis = 0, *args: Any, **kwargs: Any - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().aggregate(func, axis, *args, **kwargs) - - agg = aggregate - - def apply( - self, - func: AggFuncType | UserDefinedFunction, - axis: Axis = 0, - raw: bool = False, - result_type: Literal["expand", "reduce", "broadcast"] | None = None, - args=(), - **kwargs, - ): - """ - Apply a function along an axis of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) - query_compiler = self._query_compiler.apply( - func, - axis, - raw=raw, - result_type=result_type, - args=args, - **kwargs, - ) - if not isinstance(query_compiler, type(self._query_compiler)): - # A scalar was returned - return query_compiler - - # If True, it is an unamed series. - # Theoretically, if df.apply returns a Series, it will only be an unnamed series - # because the function is supposed to be series -> scalar. - if query_compiler._modin_frame.is_unnamed_series(): - return Series(query_compiler=query_compiler) - else: - return self.__constructor__(query_compiler=query_compiler) - - def groupby( - self, - by=None, - axis: Axis | NoDefault = no_default, - level: IndexLabel | None = None, - as_index: bool = True, - sort: bool = True, - group_keys: bool = True, - observed: bool | NoDefault = no_default, - dropna: bool = True, - ): - """ - Group ``DataFrame`` using a mapper or by a ``Series`` of columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if axis is not no_default: - axis = self._get_axis_number(axis) - if axis == 1: - warnings.warn( - "DataFrame.groupby with axis=1 is deprecated. Do " - + "`frame.T.groupby(...)` without axis instead.", - FutureWarning, - stacklevel=1, - ) - else: - warnings.warn( - "The 'axis' keyword in DataFrame.groupby is deprecated and " - + "will be removed in a future version.", - FutureWarning, - stacklevel=1, - ) - else: - axis = 0 - - validate_groupby_args(by, level, observed) - - axis = self._get_axis_number(axis) - - if axis != 0 and as_index is False: - raise ValueError("as_index=False only valid for axis=0") - - idx_name = None - - if ( - not isinstance(by, Series) - and is_list_like(by) - and len(by) == 1 - # if by is a list-like of (None,), we have to keep it as a list because - # None may be referencing a column or index level whose label is - # `None`, and by=None wold mean that there is no `by` param. 
- and by[0] is not None - ): - by = by[0] - - if hashable(by) and ( - not callable(by) and not isinstance(by, (pandas.Grouper, FrozenList)) - ): - idx_name = by - elif isinstance(by, Series): - idx_name = by.name - if by._parent is self: - # if the SnowSeries comes from the current dataframe, - # convert it to labels directly for easy processing - by = by.name - elif is_list_like(by): - if axis == 0 and all( - ( - (hashable(o) and (o in self)) - or isinstance(o, Series) - or (is_list_like(o) and len(o) == len(self.shape[axis])) - ) - for o in by - ): - # plit 'by's into those that belongs to the self (internal_by) - # and those that doesn't (external_by). For SnowSeries that belongs - # to current DataFrame, we convert it to labels for easy process. - internal_by, external_by = [], [] - - for current_by in by: - if hashable(current_by): - internal_by.append(current_by) - elif isinstance(current_by, Series): - if current_by._parent is self: - internal_by.append(current_by.name) - else: - external_by.append(current_by) # pragma: no cover - else: - external_by.append(current_by) - - by = internal_by + external_by - - return DataFrameGroupBy( - self, - by, - axis, - level, - as_index, - sort, - group_keys, - idx_name, - observed=observed, - dropna=dropna, - ) - - def keys(self): # noqa: RT01, D200 - """ - Get columns of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.columns - - def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200 - """ - Transpose index and columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if copy: - WarningMessage.ignored_argument( - operation="transpose", - argument="copy", - message="Transpose ignore copy argument in Snowpark pandas API", - ) - - if args: - WarningMessage.ignored_argument( - operation="transpose", - argument="args", - message="Transpose ignores args in Snowpark pandas API", - ) - - return self.__constructor__(query_compiler=self._query_compiler.transpose()) - - T = property(transpose) - - def add( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get addition of ``DataFrame`` and `other`, element-wise (binary operator `add`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "add", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def assign(self, **kwargs): # noqa: PR01, RT01, D200 - """ - Assign new columns to a ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - df = self.copy() - for k, v in kwargs.items(): - if callable(v): - df[k] = v(df) - else: - df[k] = v - return df - - @dataframe_not_implemented() - def boxplot( - self, - column=None, - by=None, - ax=None, - fontsize=None, - rot=0, - grid=True, - figsize=None, - layout=None, - return_type=None, - backend=None, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Make a box plot from ``DataFrame`` columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return to_pandas(self).boxplot( - column=column, - by=by, - ax=ax, - fontsize=fontsize, - rot=rot, - grid=grid, - figsize=figsize, - layout=layout, - return_type=return_type, - backend=backend, - **kwargs, - ) - - @dataframe_not_implemented() - def combine( - self, other, func, fill_value=None, overwrite=True - ): # noqa: PR01, RT01, D200 - """ - Perform column-wise combine with another ``DataFrame``. 
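# Illustration only, not part of the diff: a minimal sketch of the `by`
# normalization that groupby() above performs, assuming Snowpark pandas
# mirrors native pandas semantics; the frame `df` is hypothetical.
import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2], "b": [10, 20, 30]})

# A Series that originates from the frame itself is reduced to its label
# ("a"), so both calls group on the same key and give identical results.
by_label = df.groupby("a")["b"].sum()
by_series = df.groupby(df["a"])["b"].sum()
assert by_label.equals(by_series)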
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().combine(other, func, fill_value=fill_value, overwrite=overwrite) - - def compare( - self, - other, - align_axis=1, - keep_shape: bool = False, - keep_equal: bool = False, - result_names=("self", "other"), - ) -> DataFrame: # noqa: PR01, RT01, D200 - """ - Compare to another ``DataFrame`` and show the differences. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not isinstance(other, DataFrame): - raise TypeError(f"Cannot compare DataFrame to {type(other)}") - other = self._validate_other(other, 0, compare_index=True) - return self.__constructor__( - query_compiler=self._query_compiler.compare( - other, - align_axis=align_axis, - keep_shape=keep_shape, - keep_equal=keep_equal, - result_names=result_names, - ) - ) - - def corr( - self, - method: str | Callable = "pearson", - min_periods: int | None = None, - numeric_only: bool = False, - ): # noqa: PR01, RT01, D200 - """ - Compute pairwise correlation of columns, excluding NA/null values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - corr_df = self - if numeric_only: - corr_df = self.drop( - columns=[ - i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) - ] - ) - return self.__constructor__( - query_compiler=corr_df._query_compiler.corr( - method=method, - min_periods=min_periods, - ) - ) - - @dataframe_not_implemented() - def corrwith( - self, other, axis=0, drop=False, method="pearson", numeric_only=False - ): # noqa: PR01, RT01, D200 - """ - Compute pairwise correlation. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, DataFrame): - other = other._query_compiler.to_pandas() - return self._default_to_pandas( - pandas.DataFrame.corrwith, - other, - axis=axis, - drop=drop, - method=method, - numeric_only=numeric_only, - ) - - @dataframe_not_implemented() - def cov( - self, - min_periods: int | None = None, - ddof: int | None = 1, - numeric_only: bool = False, - ): - """ - Compute pairwise covariance of columns, excluding NA/null values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.cov( - min_periods=min_periods, - ddof=ddof, - numeric_only=numeric_only, - ) - ) - - @dataframe_not_implemented() - def dot(self, other): # noqa: PR01, RT01, D200 - """ - Compute the matrix multiplication between the ``DataFrame`` and `other`. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - if isinstance(other, BasePandasDataset): - common = self.columns.union(other.index) - if len(common) > len(self.columns) or len(common) > len( - other - ): # pragma: no cover - raise ValueError("Matrices are not aligned") - - if isinstance(other, DataFrame): - return self.__constructor__( - query_compiler=self._query_compiler.dot( - other.reindex(index=common), squeeze_self=False - ) - ) - else: - return self._reduce_dimension( - query_compiler=self._query_compiler.dot( - other.reindex(index=common), squeeze_self=False - ) - ) - - other = np.asarray(other) - if self.shape[1] != other.shape[0]: - raise ValueError( - f"Dot product shape mismatch, {self.shape} vs {other.shape}" - ) - - if len(other.shape) > 1: - return self.__constructor__( - query_compiler=self._query_compiler.dot(other, squeeze_self=False) - ) - - return self._reduce_dimension( - query_compiler=self._query_compiler.dot(other, squeeze_self=False) - ) - - def eq(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Perform equality comparison of ``DataFrame`` and `other` (binary operator `eq`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("eq", other, axis=axis, level=level) - - def equals(self, other) -> bool: # noqa: PR01, RT01, D200 - """ - Test whether two objects contain the same elements. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, pandas.DataFrame): - # Copy into a Modin DataFrame to simplify logic below - other = self.__constructor__(other) - - if ( - type(self) is not type(other) - or not self.index.equals(other.index) - or not self.columns.equals(other.columns) - ): - return False - - result = self.__constructor__( - query_compiler=self._query_compiler.equals(other._query_compiler) - ) - return result.all(axis=None) - - def _update_var_dicts_in_kwargs(self, expr, kwargs): - """ - Copy variables with "@" prefix in `local_dict` and `global_dict` keys of kwargs. - - Parameters - ---------- - expr : str - The expression string to search variables with "@" prefix. - kwargs : dict - See the documentation for eval() for complete details on the keyword arguments accepted by query(). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if "@" not in expr: - return - frame = sys._getframe() - try: - f_locals = frame.f_back.f_back.f_back.f_back.f_locals - f_globals = frame.f_back.f_back.f_back.f_back.f_globals - finally: - del frame - local_names = set(re.findall(r"@([\w]+)", expr)) - local_dict = {} - global_dict = {} - - for name in local_names: - for dct_out, dct_in in ((local_dict, f_locals), (global_dict, f_globals)): - try: - dct_out[name] = dct_in[name] - except KeyError: - pass - - if local_dict: - local_dict.update(kwargs.get("local_dict") or {}) - kwargs["local_dict"] = local_dict - if global_dict: - global_dict.update(kwargs.get("global_dict") or {}) - kwargs["global_dict"] = global_dict - - @dataframe_not_implemented() - def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 - """ - Evaluate a string describing operations on ``DataFrame`` columns. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._validate_eval_query(expr, **kwargs) - inplace = validate_bool_kwarg(inplace, "inplace") - self._update_var_dicts_in_kwargs(expr, kwargs) - new_query_compiler = self._query_compiler.eval(expr, **kwargs) - return_type = type( - pandas.DataFrame(columns=self.columns) - .astype(self.dtypes) - .eval(expr, **kwargs) - ).__name__ - if return_type == type(self).__name__: - return self._create_or_update_from_compiler(new_query_compiler, inplace) - else: - if inplace: - raise ValueError("Cannot operate inplace if there is no assignment") - return getattr(sys.modules[self.__module__], return_type)( - query_compiler=new_query_compiler - ) - - def fillna( - self, - value: Hashable | Mapping | Series | DataFrame = None, - *, - method: FillnaOptions | None = None, - axis: Axis | None = None, - inplace: bool = False, - limit: int | None = None, - downcast: dict | None = None, - ) -> DataFrame | None: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().fillna( - self_is_series=False, - value=value, - method=method, - axis=axis, - inplace=inplace, - limit=limit, - downcast=downcast, - ) - - def floordiv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `floordiv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "floordiv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - @classmethod - @dataframe_not_implemented() - def from_dict( - cls, data, orient="columns", dtype=None, columns=None - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Construct ``DataFrame`` from dict of array-like or dicts. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return from_pandas( - pandas.DataFrame.from_dict( - data, orient=orient, dtype=dtype, columns=columns - ) - ) - - @classmethod - @dataframe_not_implemented() - def from_records( - cls, - data, - index=None, - exclude=None, - columns=None, - coerce_float=False, - nrows=None, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Convert structured or record ndarray to ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return from_pandas( - pandas.DataFrame.from_records( - data, - index=index, - exclude=exclude, - columns=columns, - coerce_float=coerce_float, - nrows=nrows, - ) - ) - - def ge(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get greater than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("ge", other, axis=axis, level=level) - - def gt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get greater than comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("gt", other, axis=axis, level=level) - - @dataframe_not_implemented() - def hist( - self, - column=None, - by=None, - grid=True, - xlabelsize=None, - xrot=None, - ylabelsize=None, - yrot=None, - ax=None, - sharex=False, - sharey=False, - figsize=None, - layout=None, - bins=10, - **kwds, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Make a histogram of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.hist, - column=column, - by=by, - grid=grid, - xlabelsize=xlabelsize, - xrot=xrot, - ylabelsize=ylabelsize, - yrot=yrot, - ax=ax, - sharex=sharex, - sharey=sharey, - figsize=figsize, - layout=layout, - bins=bins, - **kwds, - ) - - def info( - self, - verbose: bool | None = None, - buf: IO[str] | None = None, - max_cols: int | None = None, - memory_usage: bool | str | None = None, - show_counts: bool | None = None, - null_counts: bool | None = None, - ): # noqa: PR01, D200 - """ - Print a concise summary of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def put_str(src, output_len=None, spaces=2): - src = str(src) - return src.ljust(output_len if output_len else len(src)) + " " * spaces - - def format_size(num): - for x in ["bytes", "KB", "MB", "GB", "TB"]: - if num < 1024.0: - return f"{num:3.1f} {x}" - num /= 1024.0 - return f"{num:3.1f} PB" - - output = [] - - type_line = str(type(self)) - index_line = "SnowflakeIndex" - columns = self.columns - columns_len = len(columns) - dtypes = self.dtypes - dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}" - - if max_cols is None: - max_cols = 100 - - exceeds_info_cols = columns_len > max_cols - - if buf is None: - buf = sys.stdout - - if null_counts is None: - null_counts = not exceeds_info_cols - - if verbose is None: - verbose = not exceeds_info_cols - - if null_counts and verbose: - # We're gonna take items from `non_null_count` in a loop, which - # works kinda slow with `Modin.Series`, that's why we call `_to_pandas()` here - # that will be faster. 
- non_null_count = self.count()._to_pandas() - - if memory_usage is None: - memory_usage = True - - def get_header(spaces=2): - output = [] - head_label = " # " - column_label = "Column" - null_label = "Non-Null Count" - dtype_label = "Dtype" - non_null_label = " non-null" - delimiter = "-" - - lengths = {} - lengths["head"] = max(len(head_label), len(pprint_thing(len(columns)))) - lengths["column"] = max( - len(column_label), max(len(pprint_thing(col)) for col in columns) - ) - lengths["dtype"] = len(dtype_label) - dtype_spaces = ( - max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes)) - - lengths["dtype"] - ) - - header = put_str(head_label, lengths["head"]) + put_str( - column_label, lengths["column"] - ) - if null_counts: - lengths["null"] = max( - len(null_label), - max(len(pprint_thing(x)) for x in non_null_count) - + len(non_null_label), - ) - header += put_str(null_label, lengths["null"]) - header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces) - - output.append(header) - - delimiters = put_str(delimiter * lengths["head"]) + put_str( - delimiter * lengths["column"] - ) - if null_counts: - delimiters += put_str(delimiter * lengths["null"]) - delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces) - output.append(delimiters) - - return output, lengths - - output.extend([type_line, index_line]) - - def verbose_repr(output): - columns_line = f"Data columns (total {len(columns)} columns):" - header, lengths = get_header() - output.extend([columns_line, *header]) - for i, col in enumerate(columns): - i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]]) - - to_append = put_str(f" {i}", lengths["head"]) + put_str( - col_s, lengths["column"] - ) - if null_counts: - non_null = pprint_thing(non_null_count[col]) - to_append += put_str(f"{non_null} non-null", lengths["null"]) - to_append += put_str(dtype, lengths["dtype"], spaces=0) - output.append(to_append) - - def non_verbose_repr(output): - output.append(columns._summary(name="Columns")) - - if verbose: - verbose_repr(output) - else: - non_verbose_repr(output) - - output.append(dtypes_line) - - if memory_usage: - deep = memory_usage == "deep" - mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum() - mem_line = f"memory usage: {format_size(mem_usage_bytes)}" - - output.append(mem_line) - - output.append("") - buf.write("\n".join(output)) - - def insert( - self, - loc: int, - column: Hashable, - value: Scalar | AnyArrayLike, - allow_duplicates: bool | NoDefault = no_default, - ) -> None: - """ - Insert column into ``DataFrame`` at specified location. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - raise_if_native_pandas_objects(value) - if allow_duplicates is no_default: - allow_duplicates = False - if not allow_duplicates and column in self.columns: - raise ValueError(f"cannot insert {column}, already exists") - - if not isinstance(loc, int): - raise TypeError("loc must be int") - - # If columns labels are multilevel, we implement following behavior (this is - # name native pandas): - # Case 1: if 'column' is tuple it's length must be same as number of levels - # otherwise raise error. - # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in - # empty strings to match the length of column levels in self frame. - if self.columns.nlevels > 1: - if isinstance(column, tuple) and len(column) != self.columns.nlevels: - # same error as native pandas. 
- raise ValueError("Item must have length equal to number of levels.") - if not isinstance(column, tuple): - # Fill empty strings to match length of levels - suffix = [""] * (self.columns.nlevels - 1) - column = tuple([column] + suffix) - - # Dictionary keys are treated as index column and this should be joined with - # index of target dataframe. This behavior is similar to 'value' being DataFrame - # or Series, so we simply create Series from dict data here. - if isinstance(value, dict): - value = Series(value, name=column) - - if isinstance(value, DataFrame) or ( - isinstance(value, np.ndarray) and len(value.shape) > 1 - ): - # Supported numpy array shapes are - # 1. (N, ) -> Ex. [1, 2, 3] - # 2. (N, 1) -> Ex> [[1], [2], [3]] - if value.shape[1] != 1: - if isinstance(value, DataFrame): - # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin - raise ValueError( - f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead." - ) - else: - raise ValueError( - f"Expected a 1D array, got an array with shape {value.shape}" - ) - # Change numpy array shape from (N, 1) to (N, ) - if isinstance(value, np.ndarray): - value = value.squeeze(axis=1) - - if ( - is_list_like(value) - and not isinstance(value, (Series, DataFrame)) - and len(value) != self.shape[0] - and not 0 == self.shape[0] # dataframe holds no rows - ): - raise ValueError( - "Length of values ({}) does not match length of index ({})".format( - len(value), len(self) - ) - ) - if not -len(self.columns) <= loc <= len(self.columns): - raise IndexError( - f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}" - ) - elif loc < 0: - raise ValueError("unbounded slice") - - join_on_index = False - if isinstance(value, (Series, DataFrame)): - value = value._query_compiler - join_on_index = True - elif is_list_like(value): - value = Series(value, name=column)._query_compiler - - new_query_compiler = self._query_compiler.insert( - loc, column, value, join_on_index - ) - # In pandas, 'insert' operation is always inplace. - self._update_inplace(new_query_compiler=new_query_compiler) - - @dataframe_not_implemented() - def interpolate( - self, - method="linear", - axis=0, - limit=None, - inplace=False, - limit_direction: str | None = None, - limit_area=None, - downcast=None, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Fill NaN values using an interpolation method. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.interpolate, - method=method, - axis=axis, - limit=limit, - inplace=inplace, - limit_direction=limit_direction, - limit_area=limit_area, - downcast=downcast, - **kwargs, - ) - - def iterrows(self) -> Iterator[tuple[Hashable, Series]]: - """ - Iterate over ``DataFrame`` rows as (index, ``Series``) pairs. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def iterrow_builder(s): - """Return tuple of the given `s` parameter name and the parameter themselves.""" - return s.name, s - - # Raise warning message since iterrows is very inefficient. - WarningMessage.single_warning( - DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows") - ) - - partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder) - yield from partition_iterator - - def items(self): # noqa: D200 - """ - Iterate over (column name, ``Series``) pairs. 
- """ - - def items_builder(s): - """Return tuple of the given `s` parameter name and the parameter themselves.""" - return s.name, s - - partition_iterator = PartitionIterator(self, 1, items_builder) - yield from partition_iterator - - def itertuples( - self, index: bool = True, name: str | None = "Pandas" - ) -> Iterable[tuple[Any, ...]]: - """ - Iterate over ``DataFrame`` rows as ``namedtuple``-s. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - def itertuples_builder(s): - """Return the next namedtuple.""" - # s is the Series of values in the current row. - fields = [] # column names - data = [] # values under each column - - if index: - data.append(s.name) - fields.append("Index") - - # Fill column names and values. - fields.extend(list(self.columns)) - data.extend(s) - - if name is not None: - # Creating the namedtuple. - itertuple = collections.namedtuple(name, fields, rename=True) - return itertuple._make(data) - - # When the name is None, return a regular tuple. - return tuple(data) - - # Raise warning message since itertuples is very inefficient. - WarningMessage.single_warning( - DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples") - ) - return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True) - - def join( - self, - other: DataFrame | Series | Iterable[DataFrame | Series], - on: IndexLabel | None = None, - how: str = "left", - lsuffix: str = "", - rsuffix: str = "", - sort: bool = False, - validate: str | None = None, - ) -> DataFrame: - """ - Join columns of another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - for o in other if isinstance(other, list) else [other]: - raise_if_native_pandas_objects(o) - - # Similar to native pandas we implement 'join' using 'pd.merge' method. - # Following code is copied from native pandas (with few changes explained below) - # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002 - if isinstance(other, Series): - # Same error as native pandas. - if other.name is None: - raise ValueError("Other Series must have a name") - other = DataFrame(other) - elif is_list_like(other): - if any([isinstance(o, Series) and o.name is None for o in other]): - raise ValueError("Other Series must have a name") - - if isinstance(other, DataFrame): - if how == "cross": - return pd.merge( - self, - other, - how=how, - on=on, - suffixes=(lsuffix, rsuffix), - sort=sort, - validate=validate, - ) - return pd.merge( - self, - other, - left_on=on, - how=how, - left_index=on is None, - right_index=True, - suffixes=(lsuffix, rsuffix), - sort=sort, - validate=validate, - ) - else: # List of DataFrame/Series - # Same error as native pandas. - if on is not None: - raise ValueError( - "Joining multiple DataFrames only supported for joining on index" - ) - - # Same error as native pandas. - if rsuffix or lsuffix: - raise ValueError( - "Suffixes not supported when joining multiple DataFrames" - ) - - # NOTE: These are not the differences between Snowpark pandas API and pandas behavior - # these are differences between native pandas join behavior when join - # frames have unique index or not. - - # In native pandas logic to join multiple DataFrames/Series is data - # dependent. Under the hood it will either use 'concat' or 'merge' API - # Case 1. If all objects being joined have unique index use 'concat' (axis=1) - # Case 2. Otherwise use 'merge' API by looping through objects left to right. 
- # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046 - - # Even though concat (axis=1) and merge are very similar APIs they have - # some differences which leads to inconsistent behavior in native pandas. - # 1. Treatment of un-named Series - # Case #1: Un-named series is allowed in concat API. Objects are joined - # successfully by assigning a number as columns name (see 'concat' API - # documentation for details on treatment of un-named series). - # Case #2: It raises 'ValueError: Other Series must have a name' - - # 2. how='right' - # Case #1: 'concat' API doesn't support right join. It raises - # 'ValueError: Only can inner (intersect) or outer (union) join the other axis' - # Case #2: Merges successfully. - - # 3. Joining frames with duplicate labels but no conflict with other frames - # Example: self = DataFrame(... columns=["A", "B"]) - # other = [DataFrame(... columns=["C", "C"])] - # Case #1: 'ValueError: Indexes have overlapping values' - # Case #2: Merged successfully. - - # In addition to this, native pandas implementation also leads to another - # type of inconsistency where left.join(other, ...) and - # left.join([other], ...) might behave differently for cases mentioned - # above. - # Example: - # import pandas as pd - # df = pd.DataFrame({"a": [4, 5]}) - # other = pd.Series([1, 2]) - # df.join([other]) # this is successful - # df.join(other) # this raises 'ValueError: Other Series must have a name' - - # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API - # to join multiple DataFrame/Series. So always follow the behavior - # documented as Case #2 above. - - joined = self - for frame in other: - if isinstance(frame, DataFrame): - overlapping_cols = set(joined.columns).intersection( - set(frame.columns) - ) - if len(overlapping_cols) > 0: - # Native pandas raises: 'Indexes have overlapping values' - # We differ slightly from native pandas message to make it more - # useful to users. - raise ValueError( - f"Join dataframes have overlapping column labels: {overlapping_cols}" - ) - joined = pd.merge( - joined, - frame, - how=how, - left_index=True, - right_index=True, - validate=validate, - sort=sort, - suffixes=(None, None), - ) - return joined - - def isna(self): - return super().isna() - - def isnull(self): - return super().isnull() - - @dataframe_not_implemented() - def isetitem(self, loc, value): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.isetitem, - loc=loc, - value=value, - ) - - def le(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get less than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("le", other, axis=axis, level=level) - - def lt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get less than comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("lt", other, axis=axis, level=level) - - def melt( - self, - id_vars=None, - value_vars=None, - var_name=None, - value_name="value", - col_level=None, - ignore_index=True, - ): # noqa: PR01, RT01, D200 - """ - Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set. 
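# Illustration only, not part of the diff: as the comments above explain,
# joining a list of frames is implemented as repeated index-on-index merges;
# a native pandas sketch of one such step with hypothetical frames.
import pandas as pd

left = pd.DataFrame({"a": [1, 2]})
right = pd.DataFrame({"b": [3, 4]})
overlapping = set(left.columns) & set(right.columns)
assert not overlapping  # otherwise the loop above raises ValueError
joined = pd.merge(left, right, how="left", left_index=True, right_index=True)
assert list(joined.columns) == ["a", "b"]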
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if id_vars is None: - id_vars = [] - if not is_list_like(id_vars): - id_vars = [id_vars] - if value_vars is None: - # Behavior of Index.difference changed in 2.2.x - # https://github.com/pandas-dev/pandas/pull/55113 - # This change needs upstream to Modin: - # https://github.com/modin-project/modin/issues/7206 - value_vars = self.columns.drop(id_vars) - if var_name is None: - columns_name = self._query_compiler.get_index_name(axis=1) - var_name = columns_name if columns_name is not None else "variable" - return self.__constructor__( - query_compiler=self._query_compiler.melt( - id_vars=id_vars, - value_vars=value_vars, - var_name=var_name, - value_name=value_name, - col_level=col_level, - ignore_index=ignore_index, - ) - ) - - @dataframe_not_implemented() - def memory_usage(self, index=True, deep=False): # noqa: PR01, RT01, D200 - """ - Return the memory usage of each column in bytes. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - if index: - result = self._reduce_dimension( - self._query_compiler.memory_usage(index=False, deep=deep) - ) - index_value = self.index.memory_usage(deep=deep) - return pd.concat( - [Series(index_value, index=["Index"]), result] - ) # pragma: no cover - return super().memory_usage(index=index, deep=deep) - - def merge( - self, - right: DataFrame | Series, - how: str = "inner", - on: IndexLabel | None = None, - left_on: Hashable - | AnyArrayLike - | Sequence[Hashable | AnyArrayLike] - | None = None, - right_on: Hashable - | AnyArrayLike - | Sequence[Hashable | AnyArrayLike] - | None = None, - left_index: bool = False, - right_index: bool = False, - sort: bool = False, - suffixes: Suffixes = ("_x", "_y"), - copy: bool = True, - indicator: bool = False, - validate: str | None = None, - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # Raise error if native pandas objects are passed. - raise_if_native_pandas_objects(right) - - if isinstance(right, Series) and right.name is None: - raise ValueError("Cannot merge a Series without a name") - if not isinstance(right, (Series, DataFrame)): - raise TypeError( - f"Can only merge Series or DataFrame objects, a {type(right)} was passed" - ) - - if isinstance(right, Series): - right_column_nlevels = ( - len(right.name) if isinstance(right.name, tuple) else 1 - ) - else: - right_column_nlevels = right.columns.nlevels - if self.columns.nlevels != right_column_nlevels: - # This is deprecated in native pandas. We raise explicit error for this. - raise ValueError( - "Can not merge objects with different column levels." - + f" ({self.columns.nlevels} levels on the left," - + f" {right_column_nlevels} on the right)" - ) - - # Merge empty native pandas dataframes for error checking. Otherwise, it will - # require a lot of logic to be written. This takes care of raising errors for - # following scenarios: - # 1. Only 'left_index' is set to True. - # 2. Only 'right_index is set to True. - # 3. Only 'left_on' is provided. - # 4. Only 'right_on' is provided. - # 5. 'on' and 'left_on' both are provided - # 6. 'on' and 'right_on' both are provided - # 7. 'on' and 'left_index' both are provided - # 8. 'on' and 'right_index' both are provided - # 9. 'left_on' and 'left_index' both are provided - # 10. 'right_on' and 'right_index' both are provided - # 11. Length mismatch between 'left_on' and 'right_on' - # 12. 'left_index' is not a bool - # 13. 'right_index' is not a bool - # 14. 
'on' is not None and how='cross' - # 15. 'left_on' is not None and how='cross' - # 16. 'right_on' is not None and how='cross' - # 17. 'left_index' is True and how='cross' - # 18. 'right_index' is True and how='cross' - # 19. Unknown label in 'on', 'left_on' or 'right_on' - # 20. Provided 'suffixes' is not sufficient to resolve conflicts. - # 21. Merging on column with duplicate labels. - # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'} - # 23. conflict with existing labels for array-like join key - # 24. 'indicator' argument is not bool or str - # 25. indicator column label conflicts with existing data labels - create_empty_native_pandas_frame(self).merge( - create_empty_native_pandas_frame(right), - on=on, - how=how, - left_on=replace_external_data_keys_with_empty_pandas_series(left_on), - right_on=replace_external_data_keys_with_empty_pandas_series(right_on), - left_index=left_index, - right_index=right_index, - suffixes=suffixes, - indicator=indicator, - ) - - return self.__constructor__( - query_compiler=self._query_compiler.merge( - right._query_compiler, - how=how, - on=on, - left_on=replace_external_data_keys_with_query_compiler(self, left_on), - right_on=replace_external_data_keys_with_query_compiler( - right, right_on - ), - left_index=left_index, - right_index=right_index, - sort=sort, - suffixes=suffixes, - copy=copy, - indicator=indicator, - validate=validate, - ) - ) - - def mod( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `mod`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "mod", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def mul( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "mul", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - multiply = mul - - def rmul( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rmul", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get not equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ne`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("ne", other, axis=axis, level=level) - - def nlargest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200 - """ - Return the first `n` rows ordered by `columns` in descending order. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.nlargest(n, columns, keep) - ) - - def nsmallest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200 - """ - Return the first `n` rows ordered by `columns` in ascending order. 
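# Illustration only, not part of the diff: the merge() path above validates
# its arguments by merging empty native pandas frames with the same schema,
# so invalid combinations (for example 'on' together with 'left_index') raise
# the same error native pandas would; a sketch with hypothetical frames.
import pandas as pd

left = pd.DataFrame({"k": [1], "v": [2]})
right = pd.DataFrame({"k": [1], "w": [3]})
try:
    left.head(0).merge(right.head(0), on="k", left_index=True)
except pd.errors.MergeError:
    pass  # the empty-frame merge surfaces the error without moving any data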
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.nsmallest( - n=n, columns=columns, keep=keep - ) - ) - - def unstack( - self, - level: int | str | list = -1, - fill_value: int | str | dict = None, - sort: bool = True, - ): - """ - Pivot a level of the (necessarily hierarchical) index labels. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # This ensures that non-pandas MultiIndex objects are caught. - nlevels = self._query_compiler.nlevels() - is_multiindex = nlevels > 1 - - if not is_multiindex or ( - is_multiindex and is_list_like(level) and len(level) == nlevels - ): - return self._reduce_dimension( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=False - ) - ) - else: - return self.__constructor__( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=False - ) - ) - - def pivot( - self, - *, - columns: Any, - index: Any | NoDefault = no_default, - values: Any | NoDefault = no_default, - ): - """ - Return reshaped DataFrame organized by given index / column values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if index is no_default: - index = None # pragma: no cover - if values is no_default: - values = None - - # if values is not specified, it should be the remaining columns not in - # index or columns - if values is None: - values = list(self.columns) - if index is not None: - values = [v for v in values if v not in index] - if columns is not None: - values = [v for v in values if v not in columns] - - return self.__constructor__( - query_compiler=self._query_compiler.pivot( - index=index, columns=columns, values=values - ) - ) - - def pivot_table( - self, - values=None, - index=None, - columns=None, - aggfunc="mean", - fill_value=None, - margins=False, - dropna=True, - margins_name="All", - observed=False, - sort=True, - ): - """ - Create a spreadsheet-style pivot table as a ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - result = self.__constructor__( - query_compiler=self._query_compiler.pivot_table( - index=index, - values=values, - columns=columns, - aggfunc=aggfunc, - fill_value=fill_value, - margins=margins, - dropna=dropna, - margins_name=margins_name, - observed=observed, - sort=sort, - ) - ) - return result - - @dataframe_not_implemented() - @property - def plot( - self, - x=None, - y=None, - kind="line", - ax=None, - subplots=False, - sharex=None, - sharey=False, - layout=None, - figsize=None, - use_index=True, - title=None, - grid=None, - legend=True, - style=None, - logx=False, - logy=False, - loglog=False, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - rot=None, - fontsize=None, - colormap=None, - table=False, - yerr=None, - xerr=None, - secondary_y=False, - sort_columns=False, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Make plots of ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._to_pandas().plot - - def pow( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "pow", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - @dataframe_not_implemented() - def prod( - self, - axis=None, - skipna=True, - numeric_only=False, - min_count=0, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Return the product of the values over the requested axis. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - validate_bool_kwarg(skipna, "skipna", none_allowed=False) - axis = self._get_axis_number(axis) - axis_to_apply = self.columns if axis else self.index - if ( - skipna is not False - and numeric_only is None - and min_count > len(axis_to_apply) - ): - new_index = self.columns if not axis else self.index - return Series( - [np.nan] * len(new_index), index=new_index, dtype=np.dtype("object") - ) - - data = self._validate_dtypes_sum_prod_mean(axis, numeric_only, ignore_axis=True) - if min_count > 1: - return data._reduce_dimension( - data._query_compiler.prod_min_count( - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - min_count=min_count, - **kwargs, - ) - ) - return data._reduce_dimension( - data._query_compiler.prod( - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - min_count=min_count, - **kwargs, - ) - ) - - product = prod - - def quantile( - self, - q: Scalar | ListLike = 0.5, - axis: Axis = 0, - numeric_only: bool = False, - interpolation: Literal[ - "linear", "lower", "higher", "midpoint", "nearest" - ] = "linear", - method: Literal["single", "table"] = "single", - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().quantile( - q=q, - axis=axis, - numeric_only=numeric_only, - interpolation=interpolation, - method=method, - ) - - @dataframe_not_implemented() - def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 - """ - Query the columns of a ``DataFrame`` with a boolean expression. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._update_var_dicts_in_kwargs(expr, kwargs) - self._validate_eval_query(expr, **kwargs) - inplace = validate_bool_kwarg(inplace, "inplace") - new_query_compiler = self._query_compiler.query(expr, **kwargs) - return self._create_or_update_from_compiler(new_query_compiler, inplace) - - def rename( - self, - mapper: Renamer | None = None, - *, - index: Renamer | None = None, - columns: Renamer | None = None, - axis: Axis | None = None, - copy: bool | None = None, - inplace: bool = False, - level: Level | None = None, - errors: IgnoreRaise = "ignore", - ) -> DataFrame | None: - """ - Alter axes labels. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - if mapper is None and index is None and columns is None: - raise TypeError("must pass an index to rename") - - if index is not None or columns is not None: - if axis is not None: - raise TypeError( - "Cannot specify both 'axis' and any of 'index' or 'columns'" - ) - elif mapper is not None: - raise TypeError( - "Cannot specify both 'mapper' and any of 'index' or 'columns'" - ) - else: - # use the mapper argument - if axis and self._get_axis_number(axis) == 1: - columns = mapper - else: - index = mapper - - if copy is not None: - WarningMessage.ignored_argument( - operation="dataframe.rename", - argument="copy", - message="copy parameter has been ignored with Snowflake execution engine", - ) - - if isinstance(index, dict): - index = Series(index) - - new_qc = self._query_compiler.rename( - index_renamer=index, columns_renamer=columns, level=level, errors=errors - ) - return self._create_or_update_from_compiler( - new_query_compiler=new_qc, inplace=inplace - ) - - def reindex( - self, - labels=None, - index=None, - columns=None, - axis=None, - method=None, - copy=None, - level=None, - fill_value=np.nan, - limit=None, - tolerance=None, - ): # noqa: PR01, RT01, D200 - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - axis = self._get_axis_number(axis) - if axis == 0 and labels is not None: - index = labels - elif labels is not None: - columns = labels - return super().reindex( - index=index, - columns=columns, - method=method, - copy=copy, - level=level, - fill_value=fill_value, - limit=limit, - tolerance=tolerance, - ) - - @dataframe_not_implemented() - def reindex_like( - self, - other, - method=None, - copy: bool | None = None, - limit=None, - tolerance=None, - ) -> DataFrame: # pragma: no cover - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if copy is None: - copy = True - # docs say "Same as calling .reindex(index=other.index, columns=other.columns,...).": - # https://pandas.pydata.org/pandas-docs/version/1.4/reference/api/pandas.DataFrame.reindex_like.html - return self.reindex( - index=other.index, - columns=other.columns, - method=method, - copy=copy, - limit=limit, - tolerance=tolerance, - ) - - def replace( - self, - to_replace=None, - value=no_default, - inplace: bool = False, - limit=None, - regex: bool = False, - method: str | NoDefault = no_default, - ): - """ - Replace values given in `to_replace` with `value`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - new_query_compiler = self._query_compiler.replace( - to_replace=to_replace, - value=value, - limit=limit, - regex=regex, - method=method, - ) - return self._create_or_update_from_compiler(new_query_compiler, inplace) - - def rfloordiv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `rfloordiv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rfloordiv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def radd( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get addition of ``DataFrame`` and `other`, element-wise (binary operator `radd`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "radd", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rmod( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `rmod`). - """ - return self._binary_op( - "rmod", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200 - return super().round(decimals, args=args, **kwargs) - - def rpow( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rpow", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rsub( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `rsub`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rsub", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rtruediv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `rtruediv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rtruediv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - rdiv = rtruediv - - def select_dtypes( - self, - include: ListLike | str | type | None = None, - exclude: ListLike | str | type | None = None, - ) -> DataFrame: - """ - Return a subset of the ``DataFrame``'s columns based on the column dtypes. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # This line defers argument validation to pandas, which will raise errors on our behalf in cases - # like if `include` and `exclude` are None, the same type is specified in both lists, or a string - # dtype (as opposed to object) is specified. - pandas.DataFrame().select_dtypes(include, exclude) - - if include and not is_list_like(include): - include = [include] - elif include is None: - include = [] - if exclude and not is_list_like(exclude): - exclude = [exclude] - elif exclude is None: - exclude = [] - - sel = tuple(map(set, (include, exclude))) - - # The width of the np.int_/float_ alias differs between Windows and other platforms, so - # we need to include a workaround. 
- # https://github.com/numpy/numpy/issues/9464 - # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036 - def check_sized_number_infer_dtypes(dtype): - if (isinstance(dtype, str) and dtype == "int") or (dtype is int): - return [np.int32, np.int64] - elif dtype == "float" or dtype is float: - return [np.float64, np.float32] - else: - return [infer_dtype_from_object(dtype)] - - include, exclude = map( - lambda x: set( - itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x)) - ), - sel, - ) - # We need to index on column position rather than label in case of duplicates - include_these = pandas.Series(not bool(include), index=range(len(self.columns))) - exclude_these = pandas.Series(not bool(exclude), index=range(len(self.columns))) - - def is_dtype_instance_mapper(dtype): - return functools.partial(issubclass, dtype.type) - - for i, dtype in enumerate(self.dtypes): - if include: - include_these[i] = any(map(is_dtype_instance_mapper(dtype), include)) - if exclude: - exclude_these[i] = not any( - map(is_dtype_instance_mapper(dtype), exclude) - ) - - dtype_indexer = include_these & exclude_these - indicate = [i for i, should_keep in dtype_indexer.items() if should_keep] - # We need to use iloc instead of drop in case of duplicate column names - return self.iloc[:, indicate] - - def shift( - self, - periods: int | Sequence[int] = 1, - freq=None, - axis: Axis = 0, - fill_value: Hashable = no_default, - suffix: str | None = None, - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().shift(periods, freq, axis, fill_value, suffix) - - def set_index( - self, - keys: IndexLabel - | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable], - drop: bool = True, - append: bool = False, - inplace: bool = False, - verify_integrity: bool = False, - ) -> None | DataFrame: - """ - Set the ``DataFrame`` index using existing columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - if not isinstance(keys, list): - keys = [keys] - - # make sure key is either hashable, index, or series - label_or_series = [] - - missing = [] - columns = self.columns.tolist() - for key in keys: - raise_if_native_pandas_objects(key) - if isinstance(key, pd.Series): - label_or_series.append(key._query_compiler) - elif isinstance(key, (np.ndarray, list, Iterator)): - label_or_series.append(pd.Series(key)._query_compiler) - elif isinstance(key, (pd.Index, pandas.MultiIndex)): - label_or_series += [ - s._query_compiler for s in self._to_series_list(key) - ] - else: - if not is_hashable(key): - raise TypeError( - f'The parameter "keys" may be a column key, one-dimensional array, or a list ' - f"containing only valid column keys and one-dimensional arrays. 
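# Illustration only, not part of the diff: a simplified standalone copy of the
# helper above, which widens the bare "int"/"float" aliases to both 32- and
# 64-bit dtypes so select_dtypes behaves the same on Windows and elsewhere
# (the non-numeric fallback to infer_dtype_from_object is omitted here).
import numpy as np

def widen_sized_number(dtype):
    if dtype in ("int", int):
        return [np.int32, np.int64]
    if dtype in ("float", float):
        return [np.float64, np.float32]
    return [dtype]

assert widen_sized_number("int") == [np.int32, np.int64]
assert widen_sized_number(float) == [np.float64, np.float32]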
Received column " - f"of type {type(key)}" - ) - label_or_series.append(key) - found = key in columns - if columns.count(key) > 1: - raise ValueError(f"The column label '{key}' is not unique") - elif not found: - missing.append(key) - - if missing: - raise KeyError(f"None of {missing} are in the columns") - - new_query_compiler = self._query_compiler.set_index( - label_or_series, drop=drop, append=append - ) - - # TODO: SNOW-782633 improve this code once duplicate is supported - # this needs to pull all index which is inefficient - if verify_integrity and not new_query_compiler.index.is_unique: - duplicates = new_query_compiler.index[ - new_query_compiler.index.to_pandas().duplicated() - ].unique() - raise ValueError(f"Index has duplicate keys: {duplicates}") - - return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace) - - sparse = CachedAccessor("sparse", SparseFrameAccessor) - - def squeeze(self, axis: Axis | None = None): - """ - Squeeze 1 dimensional axis objects into scalars. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) if axis is not None else None - len_columns = self._query_compiler.get_axis_len(1) - if axis == 1 and len_columns == 1: - return Series(query_compiler=self._query_compiler) - if axis in [0, None]: - # get_axis_len(0) results in a sql query to count number of rows in current - # dataframe. We should only compute len_index if axis is 0 or None. - len_index = len(self) - if axis is None and (len_columns == 1 or len_index == 1): - return Series(query_compiler=self._query_compiler).squeeze() - if axis == 0 and len_index == 1: - return Series(query_compiler=self.T._query_compiler) - return self.copy() - - def stack( - self, - level: int | str | list = -1, - dropna: bool | NoDefault = no_default, - sort: bool | NoDefault = no_default, - future_stack: bool = False, # ignored - ): - """ - Stack the prescribed level(s) from columns to index. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if future_stack is not False: - WarningMessage.ignored_argument( # pragma: no cover - operation="DataFrame.stack", - argument="future_stack", - message="future_stack parameter has been ignored with Snowflake execution engine", - ) - if dropna is NoDefault: - dropna = True # pragma: no cover - if sort is NoDefault: - sort = True # pragma: no cover - - # This ensures that non-pandas MultiIndex objects are caught. - is_multiindex = len(self.columns.names) > 1 - if not is_multiindex or ( - is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels - ): - return self._reduce_dimension( - query_compiler=self._query_compiler.stack(level, dropna, sort) - ) - else: - return self.__constructor__( - query_compiler=self._query_compiler.stack(level, dropna, sort) - ) - - def sub( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `sub`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "sub", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - subtract = sub - - @dataframe_not_implemented() - def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Write a ``DataFrame`` to the binary Feather format. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas(pandas.DataFrame.to_feather, path, **kwargs) - - @dataframe_not_implemented() - def to_gbq( - self, - destination_table, - project_id=None, - chunksize=None, - reauth=False, - if_exists="fail", - auth_local_webserver=True, - table_schema=None, - location=None, - progress_bar=True, - credentials=None, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Write a ``DataFrame`` to a Google BigQuery table. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functionsf - return self._default_to_pandas( - pandas.DataFrame.to_gbq, - destination_table, - project_id=project_id, - chunksize=chunksize, - reauth=reauth, - if_exists=if_exists, - auth_local_webserver=auth_local_webserver, - table_schema=table_schema, - location=location, - progress_bar=progress_bar, - credentials=credentials, - ) - - @dataframe_not_implemented() - def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_orc, - path=path, - engine=engine, - index=index, - engine_kwargs=engine_kwargs, - ) - - @dataframe_not_implemented() - def to_html( - self, - buf=None, - columns=None, - col_space=None, - header=True, - index=True, - na_rep="NaN", - formatters=None, - float_format=None, - sparsify=None, - index_names=True, - justify=None, - max_rows=None, - max_cols=None, - show_dimensions=False, - decimal=".", - bold_rows=True, - classes=None, - escape=True, - notebook=False, - border=None, - table_id=None, - render_links=False, - encoding=None, - ): # noqa: PR01, RT01, D200 - """ - Render a ``DataFrame`` as an HTML table. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_html, - buf=buf, - columns=columns, - col_space=col_space, - header=header, - index=index, - na_rep=na_rep, - formatters=formatters, - float_format=float_format, - sparsify=sparsify, - index_names=index_names, - justify=justify, - max_rows=max_rows, - max_cols=max_cols, - show_dimensions=show_dimensions, - decimal=decimal, - bold_rows=bold_rows, - classes=classes, - escape=escape, - notebook=notebook, - border=border, - table_id=table_id, - render_links=render_links, - encoding=None, - ) - - @dataframe_not_implemented() - def to_parquet( - self, - path=None, - engine="auto", - compression="snappy", - index=None, - partition_cols=None, - storage_options: StorageOptions = None, - **kwargs, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - from snowflake.snowpark.modin.pandas.dispatching.factories.dispatcher import ( - FactoryDispatcher, - ) - - return FactoryDispatcher.to_parquet( - self._query_compiler, - path=path, - engine=engine, - compression=compression, - index=index, - partition_cols=partition_cols, - storage_options=storage_options, - **kwargs, - ) - - @dataframe_not_implemented() - def to_period( - self, freq=None, axis=0, copy=True - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Convert ``DataFrame`` from ``DatetimeIndex`` to ``PeriodIndex``. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().to_period(freq=freq, axis=axis, copy=copy) - - @dataframe_not_implemented() - def to_records( - self, index=True, column_dtypes=None, index_dtypes=None - ): # noqa: PR01, RT01, D200 - """ - Convert ``DataFrame`` to a NumPy record array. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_records, - index=index, - column_dtypes=column_dtypes, - index_dtypes=index_dtypes, - ) - - @dataframe_not_implemented() - def to_stata( - self, - path: FilePath | WriteBuffer[bytes], - convert_dates: dict[Hashable, str] | None = None, - write_index: bool = True, - byteorder: str | None = None, - time_stamp: datetime.datetime | None = None, - data_label: str | None = None, - variable_labels: dict[Hashable, str] | None = None, - version: int | None = 114, - convert_strl: Sequence[Hashable] | None = None, - compression: CompressionOptions = "infer", - storage_options: StorageOptions = None, - *, - value_labels: dict[Hashable, dict[float | int, str]] | None = None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_stata, - path, - convert_dates=convert_dates, - write_index=write_index, - byteorder=byteorder, - time_stamp=time_stamp, - data_label=data_label, - variable_labels=variable_labels, - version=version, - convert_strl=convert_strl, - compression=compression, - storage_options=storage_options, - value_labels=value_labels, - ) - - @dataframe_not_implemented() - def to_xml( - self, - path_or_buffer=None, - index=True, - root_name="data", - row_name="row", - na_rep=None, - attr_cols=None, - elem_cols=None, - namespaces=None, - prefix=None, - encoding="utf-8", - xml_declaration=True, - pretty_print=True, - parser="lxml", - stylesheet=None, - compression="infer", - storage_options=None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.default_to_pandas( - pandas.DataFrame.to_xml, - path_or_buffer=path_or_buffer, - index=index, - root_name=root_name, - row_name=row_name, - na_rep=na_rep, - attr_cols=attr_cols, - elem_cols=elem_cols, - namespaces=namespaces, - prefix=prefix, - encoding=encoding, - xml_declaration=xml_declaration, - pretty_print=pretty_print, - parser=parser, - stylesheet=stylesheet, - compression=compression, - storage_options=storage_options, - ) - ) - - def to_dict( - self, - orient: Literal[ - "dict", "list", "series", "split", "tight", "records", "index" - ] = "dict", - into: type[dict] = dict, - ) -> dict | list[dict]: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._to_pandas().to_dict(orient=orient, into=into) - - def to_timestamp( - self, freq=None, how="start", axis=0, copy=True - ): # noqa: PR01, RT01, D200 - """ - Cast to DatetimeIndex of timestamps, at *beginning* of period. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().to_timestamp(freq=freq, how=how, axis=axis, copy=copy) - - def truediv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `truediv`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "truediv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - div = divide = truediv - - def update( - self, other, join="left", overwrite=True, filter_func=None, errors="ignore" - ): # noqa: PR01, RT01, D200 - """ - Modify in place using non-NA values from another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not isinstance(other, DataFrame): - other = self.__constructor__(other) - query_compiler = self._query_compiler.df_update( - other._query_compiler, - join=join, - overwrite=overwrite, - filter_func=filter_func, - errors=errors, - ) - self._update_inplace(new_query_compiler=query_compiler) - - def diff( - self, - periods: int = 1, - axis: Axis = 0, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().diff( - periods=periods, - axis=axis, - ) - - def drop( - self, - labels: IndexLabel = None, - axis: Axis = 0, - index: IndexLabel = None, - columns: IndexLabel = None, - level: Level = None, - inplace: bool = False, - errors: IgnoreRaise = "raise", - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().drop( - labels=labels, - axis=axis, - index=index, - columns=columns, - level=level, - inplace=inplace, - errors=errors, - ) - - def value_counts( - self, - subset: Sequence[Hashable] | None = None, - normalize: bool = False, - sort: bool = True, - ascending: bool = False, - dropna: bool = True, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return Series( - query_compiler=self._query_compiler.value_counts( - subset=subset, - normalize=normalize, - sort=sort, - ascending=ascending, - dropna=dropna, - ), - name="proportion" if normalize else "count", - ) - - def mask( - self, - cond: DataFrame | Series | Callable | AnyArrayLike, - other: DataFrame | Series | Callable | Scalar | None = np.nan, - *, - inplace: bool = False, - axis: Axis | None = None, - level: Level | None = None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, Series) and axis is None: - raise ValueError( - "df.mask requires an axis parameter (0 or 1) when given a Series" - ) - - return super().mask( - cond, - other=other, - inplace=inplace, - axis=axis, - level=level, - ) - - def where( - self, - cond: DataFrame | Series | Callable | AnyArrayLike, - other: DataFrame | Series | Callable | Scalar | None = np.nan, - *, - inplace: bool = False, - axis: Axis | None = None, - level: Level | None = None, - ): - """ - Replace values where the condition is False. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, Series) and axis is None: - raise ValueError( - "df.where requires an axis parameter (0 or 1) when given a Series" - ) - - return super().where( - cond, - other=other, - inplace=inplace, - axis=axis, - level=level, - ) - - @dataframe_not_implemented() - def xs(self, key, axis=0, level=None, drop_level=True): # noqa: PR01, RT01, D200 - """ - Return cross-section from the ``DataFrame``. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.xs, key, axis=axis, level=level, drop_level=drop_level - ) - - def set_axis( - self, - labels: IndexLabel, - *, - axis: Axis = 0, - copy: bool | NoDefault = no_default, # ignored - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not is_scalar(axis): - raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") - return super().set_axis( - labels=labels, - # 'columns', 'rows, 'index, 0, and 1 are the only valid axis values for df. - axis=pandas.DataFrame._get_axis_name(axis), - copy=copy, - ) - - def __getattr__(self, key): - """ - Return item identified by `key`. - - Parameters - ---------- - key : hashable - Key to get. - - Returns - ------- - Any - - Notes - ----- - First try to use `__getattribute__` method. If it fails - try to get `key` from ``DataFrame`` fields. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - try: - return object.__getattribute__(self, key) - except AttributeError as err: - if key not in _ATTRS_NO_LOOKUP and key in self.columns: - return self[key] - raise err - - def __setattr__(self, key, value): - """ - Set attribute `value` identified by `key`. - - Parameters - ---------- - key : hashable - Key to set. - value : Any - Value to set. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # While we let users assign to a column labeled "x" with "df.x" , there - # are some attributes that we should assume are NOT column names and - # therefore should follow the default Python object assignment - # behavior. These are: - # - anything in self.__dict__. This includes any attributes that the - # user has added to the dataframe with, e.g., `df.c = 3`, and - # any attribute that Modin has added to the frame, e.g. - # `_query_compiler` and `_siblings` - # - `_query_compiler`, which Modin initializes before it appears in - # __dict__ - # - `_siblings`, which Modin initializes before it appears in __dict__ - # - `_cache`, which pandas.cache_readonly uses to cache properties - # before it appears in __dict__. - if key in ("_query_compiler", "_siblings", "_cache") or key in self.__dict__: - pass - elif key in self and key not in dir(self): - self.__setitem__(key, value) - # Note: return immediately so we don't keep this `key` as dataframe state. - # `__getattr__` will return the columns not present in `dir(self)`, so we do not need - # to manually track this state in the `dir`. - return - elif is_list_like(value) and key not in ["index", "columns"]: - WarningMessage.single_warning( - SET_DATAFRAME_ATTRIBUTE_WARNING - ) # pragma: no cover - object.__setattr__(self, key, value) - - def __setitem__(self, key: Any, value: Any): - """ - Set attribute `value` identified by `key`. - - Args: - key: Key to set - value: Value to set - - Note: - In the case where value is any list like or array, pandas checks the array length against the number of rows - of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw - a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if - the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use - enlargement filling with the last value in the array. 
- - Returns: - None - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - key = apply_if_callable(key, self) - if isinstance(key, DataFrame) or ( - isinstance(key, np.ndarray) and len(key.shape) == 2 - ): - # This case uses mask's codepath to perform the set, but - # we need to duplicate the code here since we are passing - # an additional kwarg `cond_fillna_with_true` to the QC here. - # We need this additional kwarg, since if df.shape - # and key.shape do not align (i.e. df has more rows), - # mask's codepath would mask the additional rows in df - # while for setitem, we need to keep the original values. - if not isinstance(key, DataFrame): - if key.dtype != bool: - raise TypeError( - "Must pass DataFrame or 2-d ndarray with boolean values only" - ) - key = DataFrame(key) - key._query_compiler._shape_hint = "array" - - if value is not None: - value = apply_if_callable(value, self) - - if isinstance(value, np.ndarray): - value = DataFrame(value) - value._query_compiler._shape_hint = "array" - elif isinstance(value, pd.Series): - # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this - # error instead, since it is more descriptive. - raise ValueError( - "setitem with a 2D key does not support Series values." - ) - - if isinstance(value, BasePandasDataset): - value = value._query_compiler - - query_compiler = self._query_compiler.mask( - cond=key._query_compiler, - other=value, - axis=None, - level=None, - cond_fillna_with_true=True, - ) - - return self._create_or_update_from_compiler(query_compiler, inplace=True) - - # Error Checking: - if (isinstance(key, pd.Series) or is_list_like(key)) and ( - isinstance(value, range) - ): - raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) - elif isinstance(value, slice): - # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. - raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) - - # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column - # key. - index, columns = slice(None), key - index_is_bool_indexer = False - if isinstance(key, slice): - if is_integer(key.start) and is_integer(key.stop): - # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as - # df.iloc[1:2, :] = val - self.iloc[key] = value - return - index, columns = key, slice(None) - elif isinstance(key, pd.Series): - if is_bool_dtype(key.dtype): - index, columns = key, slice(None) - index_is_bool_indexer = True - elif is_bool_indexer(key): - index, columns = pd.Series(key), slice(None) - index_is_bool_indexer = True - - # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case - # we have to explicitly set matching_item_columns_by_label to False for setitem. 
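Continuing the illustrative setup above, a short sketch of the row-key cases dispatched here (an integer slice is routed through iloc, and a boolean Series or other boolean indexer selects rows):

df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

# Integer slice key: same effect as df.iloc[1:3, :] = 0.
df[1:3] = 0

# Boolean Series key: the assignment applies to all columns of the matching rows.
df[df["A"] == 1] = -1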
- index = index._query_compiler if isinstance(index, BasePandasDataset) else index - columns = ( - columns._query_compiler - if isinstance(columns, BasePandasDataset) - else columns - ) - from .indexing import is_2d_array - - matching_item_rows_by_label = not is_2d_array(value) - if is_2d_array(value): - value = DataFrame(value) - item = value._query_compiler if isinstance(value, BasePandasDataset) else value - new_qc = self._query_compiler.set_2d_labels( - index, - columns, - item, - # setitem always matches item by position - matching_item_columns_by_label=False, - matching_item_rows_by_label=matching_item_rows_by_label, - index_is_bool_indexer=index_is_bool_indexer, - # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling - # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the - # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have - # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns - # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", - # "X", "X". - deduplicate_columns=True, - ) - return self._update_inplace(new_query_compiler=new_qc) - - def abs(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().abs() - - def __and__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__and__", other, axis=1) - - def __rand__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__rand__", other, axis=1) - - def __or__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__or__", other, axis=1) - - def __ror__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__ror__", other, axis=1) - - def __neg__(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().__neg__() - - def __iter__(self): - """ - Iterate over info axis. - - Returns - ------- - iterable - Iterator of the columns names. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return iter(self.columns) - - def __contains__(self, key): - """ - Check if `key` in the ``DataFrame.columns``. - - Parameters - ---------- - key : hashable - Key to check the presence in the columns. - - Returns - ------- - bool - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.columns.__contains__(key) - - def __round__(self, decimals=0): - """ - Round each value in a ``DataFrame`` to the given number of decimals. - - Parameters - ---------- - decimals : int, default: 0 - Number of decimal places to round to. - - Returns - ------- - DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().round(decimals) - - @dataframe_not_implemented() - def __delitem__(self, key): - """ - Delete item identified by `key` label. - - Parameters - ---------- - key : hashable - Key to delete. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if key not in self: - raise KeyError(key) - self._update_inplace(new_query_compiler=self._query_compiler.delitem(key)) - - __add__ = add - __iadd__ = add # pragma: no cover - __radd__ = radd - __mul__ = mul - __imul__ = mul # pragma: no cover - __rmul__ = rmul - __pow__ = pow - __ipow__ = pow # pragma: no cover - __rpow__ = rpow - __sub__ = sub - __isub__ = sub # pragma: no cover - __rsub__ = rsub - __floordiv__ = floordiv - __ifloordiv__ = floordiv # pragma: no cover - __rfloordiv__ = rfloordiv - __truediv__ = truediv - __itruediv__ = truediv # pragma: no cover - __rtruediv__ = rtruediv - __mod__ = mod - __imod__ = mod # pragma: no cover - __rmod__ = rmod - __rdiv__ = rdiv - - def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): - """ - Get a Modin DataFrame that implements the dataframe exchange protocol. - - See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. - - Parameters - ---------- - nan_as_null : bool, default: False - A keyword intended for the consumer to tell the producer - to overwrite null values in the data with ``NaN`` (or ``NaT``). - This currently has no effect; once support for nullable extension - dtypes is added, this value should be propagated to columns. - allow_copy : bool, default: True - A keyword that defines whether or not the library is allowed - to make a copy of the data. For example, copying data would be necessary - if a library supports strided buffers, given that this protocol - specifies contiguous buffers. Currently, if the flag is set to ``False`` - and a copy is needed, a ``RuntimeError`` will be raised. - - Returns - ------- - ProtocolDataframe - A dataframe object following the dataframe protocol specification. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - ErrorMessage.not_implemented( - "Snowpark pandas does not support the DataFrame interchange " - + "protocol method `__dataframe__`. To use Snowpark pandas " - + "DataFrames with third-party libraries that try to call the " - + "`__dataframe__` method, please convert this Snowpark pandas " - + "DataFrame to pandas with `to_pandas()`." - ) - - return self._query_compiler.to_dataframe( - nan_as_null=nan_as_null, allow_copy=allow_copy - ) - - @dataframe_not_implemented() - @property - def attrs(self): # noqa: RT01, D200 - """ - Return dictionary of global attributes of this dataset. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def attrs(df): - return df.attrs - - return self._default_to_pandas(attrs) - - @dataframe_not_implemented() - @property - def style(self): # noqa: RT01, D200 - """ - Return a Styler object. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def style(df): - """Define __name__ attr because properties do not have it.""" - return df.style - - return self._default_to_pandas(style) - - def isin( - self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike] - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(values, dict): - return super().isin(values) - elif isinstance(values, Series): - # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. 
- # if not values.index.is_unique: - # raise ValueError("cannot compute isin with a duplicate axis.") - return self.__constructor__( - query_compiler=self._query_compiler.isin(values._query_compiler) - ) - elif isinstance(values, DataFrame): - # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. - # if not (values.columns.is_unique and values.index.is_unique): - # raise ValueError("cannot compute isin with a duplicate axis.") - return self.__constructor__( - query_compiler=self._query_compiler.isin(values._query_compiler) - ) - else: - if not is_list_like(values): - # throw pandas compatible error - raise TypeError( - "only list-like or dict-like objects are allowed " - f"to be passed to {self.__class__.__name__}.isin(), " - f"you passed a '{type(values).__name__}'" - ) - return super().isin(values) - - def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): - """ - Return or update a ``DataFrame`` with given `new_query_compiler`. - - Parameters - ---------- - new_query_compiler : PandasQueryCompiler - QueryCompiler to use to manage the data. - inplace : bool, default: False - Whether or not to perform update or creation inplace. - - Returns - ------- - DataFrame or None - None if update was done, ``DataFrame`` otherwise. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - assert ( - isinstance(new_query_compiler, type(self._query_compiler)) - or type(new_query_compiler) in self._query_compiler.__class__.__bases__ - ), f"Invalid Query Compiler object: {type(new_query_compiler)}" - if not inplace: - return self.__constructor__(query_compiler=new_query_compiler) - else: - self._update_inplace(new_query_compiler=new_query_compiler) - - def _get_numeric_data(self, axis: int): - """ - Grab only numeric data from ``DataFrame``. - - Parameters - ---------- - axis : {0, 1} - Axis to inspect on having numeric types only. - - Returns - ------- - DataFrame - ``DataFrame`` with numeric data. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas ignores `numeric_only` if `axis` is 1, but we do have to drop - # non-numeric columns if `axis` is 0. - if axis != 0: - return self - return self.drop( - columns=[ - i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) - ] - ) - - def _validate_dtypes(self, numeric_only=False): - """ - Check that all the dtypes are the same. - - Parameters - ---------- - numeric_only : bool, default: False - Whether or not to allow only numeric data. - If True and non-numeric data is found, exception - will be raised. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - dtype = self.dtypes[0] - for t in self.dtypes: - if numeric_only and not is_numeric_dtype(t): - raise TypeError(f"{t} is not a numeric data type") - elif not numeric_only and t != dtype: - raise TypeError(f"Cannot compare type '{t}' with type '{dtype}'") - - def _validate_dtypes_sum_prod_mean(self, axis, numeric_only, ignore_axis=False): - """ - Validate data dtype for `sum`, `prod` and `mean` methods. - - Parameters - ---------- - axis : {0, 1} - Axis to validate over. - numeric_only : bool - Whether or not to allow only numeric data. - If True and non-numeric data is found, exception - will be raised. - ignore_axis : bool, default: False - Whether or not to ignore `axis` parameter. 
- - Returns - ------- - DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # We cannot add datetime types, so if we are summing a column with - # dtype datetime64 and cannot ignore non-numeric types, we must throw a - # TypeError. - if ( - not axis - and numeric_only is False - and any(dtype == np.dtype("datetime64[ns]") for dtype in self.dtypes) - ): - raise TypeError("Cannot add Timestamp Types") - - # If our DataFrame has both numeric and non-numeric dtypes then - # operations between these types do not make sense and we must raise a - # TypeError. The exception to this rule is when there are datetime and - # timedelta objects, in which case we proceed with the comparison - # without ignoring any non-numeric types. We must check explicitly if - # numeric_only is False because if it is None, it will default to True - # if the operation fails with mixed dtypes. - if ( - (axis or ignore_axis) - and numeric_only is False - and np.unique([is_numeric_dtype(dtype) for dtype in self.dtypes]).size == 2 - ): - # check if there are columns with dtypes datetime or timedelta - if all( - dtype != np.dtype("datetime64[ns]") - and dtype != np.dtype("timedelta64[ns]") - for dtype in self.dtypes - ): - raise TypeError("Cannot operate on Numeric and Non-Numeric Types") - - return self._get_numeric_data(axis) if numeric_only else self - - def _to_pandas( - self, - *, - statement_params: dict[str, str] | None = None, - **kwargs: Any, - ) -> pandas.DataFrame: - """ - Convert Snowpark pandas DataFrame to pandas DataFrame - - Args: - statement_params: Dictionary of statement level parameters to be set while executing this action. - - Returns: - pandas DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.to_pandas( - statement_params=statement_params, **kwargs - ) - - def _validate_eval_query(self, expr, **kwargs): - """ - Validate the arguments of ``eval`` and ``query`` functions. - - Parameters - ---------- - expr : str - The expression to evaluate. This string cannot contain any - Python statements, only Python expressions. - **kwargs : dict - Optional arguments of ``eval`` and ``query`` functions. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(expr, str) and expr == "": - raise ValueError("expr cannot be an empty string") - - if isinstance(expr, str) and "not" in expr: - if "parser" in kwargs and kwargs["parser"] == "python": - ErrorMessage.not_implemented( # pragma: no cover - "Snowpark pandas does not yet support 'not' in the " - + "expression for the methods `DataFrame.eval` or " - + "`DataFrame.query`" - ) - - def _reduce_dimension(self, query_compiler): - """ - Reduce the dimension of data from the `query_compiler`. - - Parameters - ---------- - query_compiler : BaseQueryCompiler - Query compiler to retrieve the data. - - Returns - ------- - Series - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return Series(query_compiler=query_compiler) - - def _set_axis_name(self, name, axis=0, inplace=False): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) - renamed = self if inplace else self.copy() - if axis == 0: - renamed.index = renamed.index.set_names(name) - else: - renamed.columns = renamed.columns.set_names(name) - if not inplace: - return renamed - - def _to_datetime(self, **kwargs): - """ - Convert `self` to datetime. 
- - Parameters - ---------- - **kwargs : dict - Optional arguments to use during query compiler's - `to_datetime` invocation. - - Returns - ------- - Series of datetime64 dtype - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._reduce_dimension( - query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) - ) - - # Persistance support methods - BEGIN - @classmethod - def _inflate_light(cls, query_compiler): - """ - Re-creates the object from previously-serialized lightweight representation. - - The method is used for faster but not disk-storable persistence. - - Parameters - ---------- - query_compiler : BaseQueryCompiler - Query compiler to use for object re-creation. - - Returns - ------- - DataFrame - New ``DataFrame`` based on the `query_compiler`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return cls(query_compiler=query_compiler) - - @classmethod - def _inflate_full(cls, pandas_df): - """ - Re-creates the object from previously-serialized disk-storable representation. - - Parameters - ---------- - pandas_df : pandas.DataFrame - Data to use for object re-creation. - - Returns - ------- - DataFrame - New ``DataFrame`` based on the `pandas_df`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return cls(data=from_pandas(pandas_df)) - - @dataframe_not_implemented() - def __reduce__(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._query_compiler.finalize() - # if PersistentPickle.get(): - # return self._inflate_full, (self._to_pandas(),) - return self._inflate_light, (self._query_compiler,) - - # Persistance support methods - END diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py index 2ca9d8e5b83..5024d0618ac 100644 --- a/src/snowflake/snowpark/modin/pandas/general.py +++ b/src/snowflake/snowpark/modin/pandas/general.py @@ -31,7 +31,7 @@ import numpy as np import pandas import pandas.core.common as common -from modin.pandas import Series +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas import IntervalIndex, NaT, Timedelta, Timestamp from pandas._libs import NaTType, lib @@ -65,7 +65,6 @@ # add this line to make doctests runnable from snowflake.snowpark.modin import pandas as pd # noqa: F401 -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import ( is_scalar, raise_if_native_pandas_objects, @@ -92,10 +91,9 @@ # linking to `snowflake.snowpark.DataFrame`, we need to explicitly # qualify return types in this file with `modin.pandas.DataFrame`. 
# SNOW-1233342: investigate how to fix these links without using absolute paths + import modin from modin.core.storage_formats import BaseQueryCompiler # pragma: no cover - import snowflake # pragma: no cover - _logger = getLogger(__name__) VALID_DATE_TYPE = Union[ @@ -137,8 +135,8 @@ def notna(obj): # noqa: PR01, RT01, D200 @snowpark_pandas_telemetry_standalone_function_decorator def merge( - left: snowflake.snowpark.modin.pandas.DataFrame | Series, - right: snowflake.snowpark.modin.pandas.DataFrame | Series, + left: modin.pandas.DataFrame | Series, + right: modin.pandas.DataFrame | Series, how: str | None = "inner", on: IndexLabel | None = None, left_on: None @@ -414,7 +412,7 @@ def merge_asof( tolerance: int | Timedelta | None = None, allow_exact_matches: bool = True, direction: str = "backward", -) -> snowflake.snowpark.modin.pandas.DataFrame: +) -> modin.pandas.DataFrame: """ Perform a merge by key distance. @@ -1105,8 +1103,8 @@ def value_counts( @snowpark_pandas_telemetry_standalone_function_decorator def concat( objs: ( - Iterable[snowflake.snowpark.modin.pandas.DataFrame | Series] - | Mapping[Hashable, snowflake.snowpark.modin.pandas.DataFrame | Series] + Iterable[modin.pandas.DataFrame | Series] + | Mapping[Hashable, modin.pandas.DataFrame | Series] ), axis: Axis = 0, join: str = "outer", @@ -1117,7 +1115,7 @@ def concat( verify_integrity: bool = False, sort: bool = False, copy: bool = True, -) -> snowflake.snowpark.modin.pandas.DataFrame | Series: +) -> modin.pandas.DataFrame | Series: """ Concatenate pandas objects along a particular axis. @@ -1490,7 +1488,7 @@ def concat( def to_datetime( arg: DatetimeScalarOrArrayConvertible | DictConvertible - | snowflake.snowpark.modin.pandas.DataFrame + | modin.pandas.DataFrame | Series, errors: DateTimeErrorChoices = "raise", dayfirst: bool = False, diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py index c672f04da63..5da10d9b7a6 100644 --- a/src/snowflake/snowpark/modin/pandas/indexing.py +++ b/src/snowflake/snowpark/modin/pandas/indexing.py @@ -45,6 +45,7 @@ import pandas from modin.pandas import Series from modin.pandas.base import BasePandasDataset +from modin.pandas.dataframe import DataFrame from pandas._libs.tslibs import Resolution, parsing from pandas._typing import AnyArrayLike, Scalar from pandas.api.types import is_bool, is_list_like @@ -61,7 +62,6 @@ import snowflake.snowpark.modin.pandas as pd import snowflake.snowpark.modin.pandas.utils as frontend_utils -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import is_scalar from snowflake.snowpark.modin.plugin._internal.indexing_utils import ( MULTIPLE_ELLIPSIS_INDEXING_ERROR_MESSAGE, diff --git a/src/snowflake/snowpark/modin/pandas/io.py b/src/snowflake/snowpark/modin/pandas/io.py index 25959212a18..b92e8ee3582 100644 --- a/src/snowflake/snowpark/modin/pandas/io.py +++ b/src/snowflake/snowpark/modin/pandas/io.py @@ -92,7 +92,7 @@ # below logic is to handle circular imports without errors if TYPE_CHECKING: # pragma: no cover - from .dataframe import DataFrame + from modin.pandas.dataframe import DataFrame # TODO: SNOW-1265551: add inherit_docstrings decorators once docstring overrides are available @@ -106,7 +106,7 @@ class ModinObjects: def DataFrame(cls): """Get ``modin.pandas.DataFrame`` class.""" if cls._dataframe is None: - from .dataframe import DataFrame + from modin.pandas.dataframe import DataFrame cls._dataframe = DataFrame return cls._dataframe diff 
--git a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py index 3529355b81b..ee782f3cdf3 100644 --- a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py +++ b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py @@ -5,10 +5,9 @@ from collections.abc import Iterator from typing import Any, Callable +import modin.pandas.dataframe as DataFrame import pandas -import snowflake.snowpark.modin.pandas.dataframe as DataFrame - PARTITION_SIZE = 4096 diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py index 3986e3d52a9..a48f16992d4 100644 --- a/src/snowflake/snowpark/modin/pandas/utils.py +++ b/src/snowflake/snowpark/modin/pandas/utils.py @@ -78,7 +78,7 @@ def from_non_pandas(df, index, columns, dtype): new_qc = FactoryDispatcher.from_non_pandas(df, index, columns, dtype) if new_qc is not None: - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame return DataFrame(query_compiler=new_qc) return new_qc @@ -99,7 +99,7 @@ def from_pandas(df): A new Modin DataFrame object. """ # from modin.core.execution.dispatching.factories.dispatcher import FactoryDispatcher - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_pandas(df)) @@ -118,10 +118,11 @@ def from_arrow(at): DataFrame A new Modin DataFrame object. """ + from modin.pandas import DataFrame + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( FactoryDispatcher, ) - from snowflake.snowpark.modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_arrow(at)) @@ -142,10 +143,11 @@ def from_dataframe(df): DataFrame A new Modin DataFrame object. """ + from modin.pandas import DataFrame + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( FactoryDispatcher, ) - from snowflake.snowpark.modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_dataframe(df)) @@ -226,7 +228,7 @@ def from_modin_frame_to_mi(df, sortorder=None, names=None): pandas.MultiIndex The pandas.MultiIndex representation of the given DataFrame. """ - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame if isinstance(df, DataFrame): df = df._to_pandas() diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py index d3ac525572a..eceb9ca7d7f 100644 --- a/src/snowflake/snowpark/modin/plugin/__init__.py +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -69,6 +69,7 @@ inherit_modules = [ (docstrings.base.BasePandasDataset, modin.pandas.base.BasePandasDataset), + (docstrings.dataframe.DataFrame, modin.pandas.dataframe.DataFrame), (docstrings.series.Series, modin.pandas.series.Series), (docstrings.series_utils.StringMethods, modin.pandas.series_utils.StringMethods), ( @@ -90,17 +91,3 @@ snowflake.snowpark._internal.utils.should_warn_dynamic_pivot_is_in_private_preview = ( False ) - - -# TODO: SNOW-1504302: Modin upgrade - use Snowpark pandas DataFrame for isocalendar -# OSS Modin's DatetimeProperties frontend class wraps the returned query compiler with `modin.pandas.DataFrame`. -# Since we currently replace `pd.DataFrame` with our own Snowpark pandas DataFrame object, this causes errors -# since OSS Modin explicitly imports its own DataFrame class here. 
This override can be removed once the frontend -# DataFrame class is removed from our codebase. -def isocalendar(self): # type: ignore - from snowflake.snowpark.modin.pandas import DataFrame - - return DataFrame(query_compiler=self._query_compiler.dt_isocalendar()) - - -modin.pandas.series_utils.DatetimeProperties.isocalendar = isocalendar diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 01ccad8f430..0005df924db 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -9,7 +9,7 @@ from collections.abc import Hashable, Iterable from functools import partial from inspect import getmembers -from types import BuiltinFunctionType +from types import BuiltinFunctionType, MappingProxyType from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Union import numpy as np @@ -56,6 +56,7 @@ stddev, stddev_pop, sum as sum_, + trunc, var_pop, variance, when, @@ -65,6 +66,9 @@ OrderedDataFrame, OrderingColumn, ) +from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + TimedeltaType, +) from snowflake.snowpark.modin.plugin._internal.utils import ( from_pandas_label, pandas_lit, @@ -85,7 +89,7 @@ } -def array_agg_keepna( +def _array_agg_keepna( column_to_aggregate: ColumnOrName, ordering_columns: Iterable[OrderingColumn] ) -> Column: """ @@ -239,62 +243,63 @@ def _columns_coalescing_idxmax_idxmin_helper( ) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding snowflake builtin aggregation function for axis=0. If any change -# is made to this map, ensure GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE and -# GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES are updated accordingly. -SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": count, - "mean": mean, - "min": min_, - "max": max_, - "idxmax": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmax" - ), - "idxmin": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmin" - ), - "sum": sum_, - "median": median, - "skew": skew, - "std": stddev, - "var": variance, - "all": builtin("booland_agg"), - "any": builtin("boolor_agg"), - np.max: max_, - np.min: min_, - np.sum: sum_, - np.mean: mean, - np.median: median, - np.std: stddev, - np.var: variance, - "array_agg": array_agg, - "quantile": column_quantile, - "nunique": count_distinct, -} -GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = ( - "min", - "max", - "sum", - "mean", - "median", - "std", - np.max, - np.min, - np.sum, - np.mean, - np.median, - np.std, -) -GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = ( - "any", - "all", - "count", - "idxmax", - "idxmin", - "size", - "nunique", -) +class _SnowparkPandasAggregation(NamedTuple): + """ + A representation of a Snowpark pandas aggregation. + + This structure gives us a common representation for an aggregation that may + have multiple aliases, like "sum" and np.sum. + """ + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. 
+ # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool + + # This callable takes a single Snowpark column as input and aggregates the + # column on axis=0. If None, Snowpark pandas does not support this + # aggregation on axis=0. + axis_0_aggregation: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # the columns on axis=1 with skipna=True, i.e. not including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=True. + axis_1_aggregation_skipna: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # the columns on axis=1 with skipna=False, i.e. including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=False. + axis_1_aggregation_keepna: Optional[Callable] = None + + +class SnowflakeAggFunc(NamedTuple): + """ + A Snowflake aggregation, including information about how the aggregation acts on SnowparkPandasType. + """ + + # The aggregation function in Snowpark. + # For aggregation on axis=0, this field should take a single Snowpark + # column and return the aggregated column. + # For aggregation on axis=1, this field should take an arbitrary number + # of Snowpark columns and return the aggregated column. + snowpark_aggregation: Callable + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. + # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool class AggFuncWithLabel(NamedTuple): @@ -413,35 +418,143 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable: return sum(builtin("zeroifnull")(col) for col in cols) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding aggregation function for axis=1 when skipna=True. The returned aggregation -# function may either be a builtin aggregation function, or a function taking in *arg columns -# that then calls the appropriate builtin aggregations. -SNOWFLAKE_COLUMNS_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": _columns_count, - "sum": _columns_coalescing_sum, - np.sum: _columns_coalescing_sum, - "min": _columns_coalescing_min, - "max": _columns_coalescing_max, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - np.min: _columns_coalescing_min, - np.max: _columns_coalescing_max, -} +def _create_pandas_to_snowpark_pandas_aggregation_map( + pandas_functions: Iterable[AggFuncTypeBase], + snowpark_pandas_aggregation: _SnowparkPandasAggregation, +) -> MappingProxyType[AggFuncTypeBase, _SnowparkPandasAggregation]: + """ + Create a map from the given pandas functions to the given _SnowparkPandasAggregation. 
-# These functions are called instead if skipna=False -SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "min": least, - "max": greatest, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark - # sum_, since Snowpark sum_ gets the sum of all rows within a single column. - "sum": lambda *cols: sum(cols), - np.sum: lambda *cols: sum(cols), - np.min: least, - np.max: greatest, -} + Args; + pandas_functions: The pandas functions that map to the given aggregation. + snowpark_pandas_aggregation: The aggregation to map to + + Returns: + The map. + """ + return MappingProxyType({k: snowpark_pandas_aggregation for k in pandas_functions}) + + +# Map between the pandas input aggregation function (str or numpy function) and +# _SnowparkPandasAggregation representing information about applying the +# aggregation in Snowpark pandas. +_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[ + AggFuncTypeBase, _SnowparkPandasAggregation +] = MappingProxyType( + { + "count": _SnowparkPandasAggregation( + axis_0_aggregation=count, + axis_1_aggregation_skipna=_columns_count, + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("mean", np.mean), + _SnowparkPandasAggregation( + axis_0_aggregation=mean, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("min", np.min), + _SnowparkPandasAggregation( + axis_0_aggregation=min_, + axis_1_aggregation_keepna=least, + axis_1_aggregation_skipna=_columns_coalescing_min, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("max", np.max), + _SnowparkPandasAggregation( + axis_0_aggregation=max_, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=_columns_coalescing_max, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("sum", np.sum), + _SnowparkPandasAggregation( + axis_0_aggregation=sum_, + # IMPORTANT: count and sum use python builtin sum to invoke + # __add__ on each column rather than Snowpark sum_, since + # Snowpark sum_ gets the sum of all rows within a single column. 
+ axis_1_aggregation_keepna=lambda *cols: sum(cols), + axis_1_aggregation_skipna=_columns_coalescing_sum, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("median", np.median), + _SnowparkPandasAggregation( + axis_0_aggregation=median, + preserves_snowpark_pandas_types=True, + ), + ), + "idxmax": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmax" + ), + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "idxmin": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmin" + ), + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "skew": _SnowparkPandasAggregation( + axis_0_aggregation=skew, + preserves_snowpark_pandas_types=True, + ), + "all": _SnowparkPandasAggregation( + # all() for a column with no non-null values is NULL in Snowflake, but True in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("booland_agg")(col(c)), pandas_lit(True) + ), + preserves_snowpark_pandas_types=False, + ), + "any": _SnowparkPandasAggregation( + # any() for a column with no non-null values is NULL in Snowflake, but False in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("boolor_agg")(col(c)), pandas_lit(False) + ), + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("std", np.std), + _SnowparkPandasAggregation( + axis_0_aggregation=stddev, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("var", np.var), + _SnowparkPandasAggregation( + axis_0_aggregation=variance, + # variance units are the square of the input column units, so + # variance does not preserve types. + preserves_snowpark_pandas_types=False, + ), + ), + "array_agg": _SnowparkPandasAggregation( + axis_0_aggregation=array_agg, + preserves_snowpark_pandas_types=False, + ), + "quantile": _SnowparkPandasAggregation( + axis_0_aggregation=column_quantile, + preserves_snowpark_pandas_types=True, + ), + "nunique": _SnowparkPandasAggregation( + axis_0_aggregation=count_distinct, + preserves_snowpark_pandas_types=False, + ), + } +) class AggregateColumnOpParameters(NamedTuple): @@ -462,7 +575,7 @@ class AggregateColumnOpParameters(NamedTuple): agg_snowflake_quoted_identifier: str # the snowflake aggregation function to apply on the column - snowflake_agg_func: Callable + snowflake_agg_func: SnowflakeAggFunc # the columns specifying the order of rows in the column. This is only # relevant for aggregations that depend on row order, e.g. summing a string @@ -471,88 +584,108 @@ class AggregateColumnOpParameters(NamedTuple): def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool: - return agg_func in SNOWFLAKE_BUILTIN_AGG_FUNC_MAP + return agg_func in _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION def get_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int = 0 -) -> Optional[Callable]: + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] +) -> Optional[SnowflakeAggFunc]: """ Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function. 
If no corresponding snowflake aggregation function can be found, return None. """ - if axis == 0: - snowflake_agg_func = SNOWFLAKE_BUILTIN_AGG_FUNC_MAP.get(agg_func) - if snowflake_agg_func == stddev or snowflake_agg_func == variance: - # for aggregation function std and var, we only support ddof = 0 or ddof = 1. - # when ddof is 1, std is mapped to stddev, var is mapped to variance - # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop - # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 - ddof = agg_kwargs.get("ddof", 1) - if ddof != 1 and ddof != 0: - return None - if ddof == 0: - return stddev_pop if snowflake_agg_func == stddev else var_pop - elif snowflake_agg_func == column_quantile: - interpolation = agg_kwargs.get("interpolation", "linear") - q = agg_kwargs.get("q", 0.5) - if interpolation not in ("linear", "nearest"): - return None - if not is_scalar(q): - # SNOW-1062878 Because list-like q would return multiple rows, calling quantile - # through the aggregate frontend in this manner is unsupported. - return None - return lambda col: column_quantile(col, interpolation, q) - elif agg_func in ("all", "any"): - # If there are no rows in the input frame, the function will also return NULL, which should - # instead by TRUE for "all" and FALSE for "any". - # Need to wrap column name in IDENTIFIER, or else the agg function will treat the name - # as a string literal. - # The generated SQL expression for "all" is - # IFNULL(BOOLAND_AGG(IDENTIFIER("column_name")), TRUE) - # The expression for "any" is - # IFNULL(BOOLOR_AGG(IDENTIFIER("column_name")), FALSE) - default_value = bool(agg_func == "all") - return lambda col: builtin("ifnull")( - # mypy refuses to acknowledge snowflake_agg_func is non-NULL here - snowflake_agg_func(builtin("identifier")(col)), # type: ignore[misc] - pandas_lit(default_value), + if axis == 1: + return _generate_rowwise_aggregation_function(agg_func, agg_kwargs) + + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + + if snowpark_pandas_aggregation is None: + # We don't have any implementation at all for this aggregation. + return None + + snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation + + if snowpark_aggregation is None: + # We don't have an implementation on axis=0 for this aggregation. + return None + + # Rewrite some aggregations according to `agg_kwargs.` + if snowpark_aggregation == stddev or snowpark_aggregation == variance: + # for aggregation function std and var, we only support ddof = 0 or ddof = 1. + # when ddof is 1, std is mapped to stddev, var is mapped to variance + # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop + # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 + ddof = agg_kwargs.get("ddof", 1) + if ddof != 1 and ddof != 0: + return None + if ddof == 0: + snowpark_aggregation = ( + stddev_pop if snowpark_aggregation == stddev else var_pop ) - else: - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) + elif snowpark_aggregation == column_quantile: + interpolation = agg_kwargs.get("interpolation", "linear") + q = agg_kwargs.get("q", 0.5) + if interpolation not in ("linear", "nearest"): + return None + if not is_scalar(q): + # SNOW-1062878 Because list-like q would return multiple rows, calling quantile + # through the aggregate frontend in this manner is unsupported. 
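A hedged sketch of how the quantile special-casing above surfaces through get_snowflake_agg_func; the keyword values are illustrative:

from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
    get_snowflake_agg_func,
)

# Scalar q with "linear" interpolation maps to a column_quantile-based aggregation.
agg = get_snowflake_agg_func("quantile", {"q": 0.25, "interpolation": "linear"}, axis=0)
assert agg is not None and agg.preserves_snowpark_pandas_types

# List-like q would produce multiple result rows, so no Snowflake mapping is returned.
assert get_snowflake_agg_func("quantile", {"q": [0.25, 0.75]}, axis=0) is None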
+ return None + + def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: + return column_quantile(col, interpolation, q) - return snowflake_agg_func + assert ( + snowpark_aggregation is not None + ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation." + return SnowflakeAggFunc( + snowpark_aggregation=snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def generate_rowwise_aggregation_function( +def _generate_rowwise_aggregation_function( agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any] -) -> Optional[Callable]: +) -> Optional[SnowflakeAggFunc]: """ Get a callable taking *arg columns to apply for an aggregation. Unlike get_snowflake_agg_func, this function may return a wrapped composition of Snowflake builtin functions depending on the values of the specified kwargs. """ - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) - if not agg_kwargs.get("skipna", True): - snowflake_agg_func = SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP.get( - agg_func, snowflake_agg_func - ) + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + if snowpark_pandas_aggregation is None: + return None + snowpark_aggregation = ( + snowpark_pandas_aggregation.axis_1_aggregation_skipna + if agg_kwargs.get("skipna", True) + else snowpark_pandas_aggregation.axis_1_aggregation_keepna + ) + if snowpark_aggregation is None: + return None min_count = agg_kwargs.get("min_count", 0) if min_count > 0: + original_aggregation = snowpark_aggregation + # Create a case statement to check if the number of non-null values exceeds min_count # when min_count > 0, if the number of not NULL values is < min_count, return NULL. - def agg_func_wrapper(fn: Callable) -> Callable: - return lambda *cols: when( - _columns_count(*cols) < min_count, pandas_lit(None) - ).otherwise(fn(*cols)) + def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn: + return when(_columns_count(*cols) < min_count, pandas_lit(None)).otherwise( + original_aggregation(*cols) + ) - return snowflake_agg_func and agg_func_wrapper(snowflake_agg_func) - return snowflake_agg_func + return SnowflakeAggFunc( + snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def is_supported_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int +def _is_supported_snowflake_agg_func( + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ check if the aggregation function is supported with snowflake. Current supported @@ -566,12 +699,14 @@ def is_supported_snowflake_agg_func( is_valid: bool. Whether it is valid to implement with snowflake or not. """ if isinstance(agg_func, tuple) and len(agg_func) == 2: + # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`, + # take the second part of the named aggregation. agg_func = agg_func[0] return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None -def are_all_agg_funcs_supported_by_snowflake( - agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: int +def _are_all_agg_funcs_supported_by_snowflake( + agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ Check if all aggregation functions in the given list are snowflake supported @@ -582,14 +717,14 @@ def are_all_agg_funcs_supported_by_snowflake( return False. 
""" return all( - is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs + _is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs ) def check_is_aggregation_supported_in_snowflake( agg_func: AggFuncType, agg_kwargs: dict[str, Any], - axis: int, + axis: Literal[0, 1], ) -> bool: """ check if distributed implementation with snowflake is available for the aggregation @@ -608,18 +743,18 @@ def check_is_aggregation_supported_in_snowflake( if is_dict_like(agg_func): return all( ( - are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) + _are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) if is_list_like(value) and not is_named_tuple(value) - else is_supported_snowflake_agg_func(value, agg_kwargs, axis) + else _is_supported_snowflake_agg_func(value, agg_kwargs, axis) ) for value in agg_func.values() ) elif is_list_like(agg_func): - return are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) - return is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) + return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) + return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) -def is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: +def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: """ Is the given snowflake aggregation function needs to be applied on the numeric column. """ @@ -697,7 +832,7 @@ def drop_non_numeric_data_columns( ) -def generate_aggregation_column( +def _generate_aggregation_column( agg_column_op_params: AggregateColumnOpParameters, agg_kwargs: dict[str, Any], is_groupby_agg: bool, @@ -721,8 +856,14 @@ def generate_aggregation_column( SnowparkColumn after the aggregation function. The column is also aliased back to the original name """ snowpark_column = agg_column_op_params.snowflake_quoted_identifier - snowflake_agg_func = agg_column_op_params.snowflake_agg_func - if is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( + snowflake_agg_func = agg_column_op_params.snowflake_agg_func.snowpark_aggregation + + if snowflake_agg_func in (variance, var_pop) and isinstance( + agg_column_op_params.data_type, TimedeltaType + ): + raise TypeError("timedelta64 type does not support var operations") + + if _is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( agg_column_op_params.data_type, BooleanType ): # if the column is a boolean column and the aggregation function requires numeric values, @@ -753,7 +894,7 @@ def generate_aggregation_column( # note that we always assume keepna for array_agg. TODO(SNOW-1040398): # make keepna treatment consistent across array_agg and other # aggregation methods. - agg_snowpark_column = array_agg_keepna( + agg_snowpark_column = _array_agg_keepna( snowpark_column, ordering_columns=agg_column_op_params.ordering_columns ) elif ( @@ -825,6 +966,19 @@ def generate_aggregation_column( ), f"No case expression is constructed with skipna({skipna}), min_count({min_count})" agg_snowpark_column = case_expr.otherwise(agg_snowpark_column) + if ( + isinstance(agg_column_op_params.data_type, TimedeltaType) + and agg_column_op_params.snowflake_agg_func.preserves_snowpark_pandas_types + ): + # timedelta aggregations that produce timedelta results might produce + # a decimal type in snowflake, e.g. + # pd.Series([pd.Timestamp(1), pd.Timestamp(2)]).mean() produces 1.5 in + # Snowflake. We truncate the decimal part of the result, as pandas + # does. 
+ agg_snowpark_column = cast( + trunc(agg_snowpark_column), agg_column_op_params.data_type.snowpark_type + ) + # rename the column to agg_column_quoted_identifier agg_snowpark_column = agg_snowpark_column.as_( agg_column_op_params.agg_snowflake_quoted_identifier @@ -857,7 +1011,7 @@ def aggregate_with_ordered_dataframe( is_groupby_agg = groupby_columns is not None agg_list: list[SnowparkColumn] = [ - generate_aggregation_column( + _generate_aggregation_column( agg_column_op_params=agg_col_op, agg_kwargs=agg_kwargs, is_groupby_agg=is_groupby_agg, @@ -973,7 +1127,7 @@ def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str: ) -def generate_pandas_labels_for_agg_result_columns( +def _generate_pandas_labels_for_agg_result_columns( pandas_label: Hashable, num_levels: int, agg_func_list: list[AggFuncInfo], @@ -1102,7 +1256,7 @@ def generate_column_agg_info( ) # generate the pandas label and quoted identifier for the result aggregation columns, one # for each aggregation function to apply. - agg_col_labels = generate_pandas_labels_for_agg_result_columns( + agg_col_labels = _generate_pandas_labels_for_agg_result_columns( pandas_label_to_identifier.pandas_label, num_levels, agg_func_list, # type: ignore[arg-type] diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py index 882dc79d2a8..4eaf98d9b29 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py @@ -189,8 +189,6 @@ def compute_bin_indices( values_frame, cuts_frame, how="asof", - left_on=[], - right_on=[], left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0], right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0], match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO diff --git a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py index c2c224e404c..6207bd2399a 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py @@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame( key, count_frame, "cross", - left_on=[], - right_on=[], inherit_join_index=InheritJoinIndex.FROM_LEFT, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 79f063b9ece..d07211dbcf5 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple): result_column_mapper: JoinOrAlignResultColumnMapper +def assert_snowpark_pandas_types_match( + left: InternalFrame, + right: InternalFrame, + left_join_identifiers: list[str], + right_join_identifiers: list[str], +) -> None: + """ + If Snowpark pandas types do not match for the given identifiers, then a ValueError will be raised. + + Args: + left: An internal frame to use on left side of join. + right: An internal frame to use on right side of join. + left_join_identifiers: List of snowflake identifiers to check types from 'left' frame. + right_join_identifiers: List of snowflake identifiers to check types from 'right' frame. + left_identifiers and right_identifiers must be lists of equal length. 
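+
+        Note that a missing Snowpark pandas type (``None``, i.e. a plain
+        Snowflake type) only matches another ``None``; for example, a Timedelta
+        join key cannot be matched against a plain integer join key.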
+ + Returns: None + + Raises: ValueError + """ + left_types = [ + left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in left_join_identifiers + ] + right_types = [ + right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in right_join_identifiers + ] + for i, (lt, rt) in enumerate(zip(left_types, right_types)): + if lt != rt: + left_on_id = left_join_identifiers[i] + idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) + key = left.data_column_pandas_labels[idx] + lt = lt if lt is not None else left.get_snowflake_type(left_on_id) + rt = ( + rt + if rt is not None + else right.get_snowflake_type(right_join_identifiers[i]) + ) + raise ValueError( + f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. " + f"If you wish to proceed you should use pd.concat" + ) + + def join( left: InternalFrame, right: InternalFrame, how: JoinTypeLit, - left_on: list[str], - right_on: list[str], + left_on: Optional[list[str]] = None, + right_on: Optional[list[str]] = None, left_match_col: Optional[str] = None, right_match_col: Optional[str] = None, match_comparator: Optional[MatchComparator] = None, @@ -161,40 +206,48 @@ def join( include mapping for index + data columns, ordering columns and row position column if exists. """ - assert len(left_on) == len( - right_on - ), "left_on and right_on must be of same length or both be None" - if join_key_coalesce_config is not None: - assert len(join_key_coalesce_config) == len( - left_on - ), "join_key_coalesce_config must be of same length as left_on and right_on" assert how in get_args( JoinTypeLit ), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}" - def assert_snowpark_pandas_types_match() -> None: - """If Snowpark pandas types do not match, then a ValueError will be raised.""" - left_types = [ - left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in left_on - ] - right_types = [ - right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in right_on - ] - for i, (lt, rt) in enumerate(zip(left_types, right_types)): - if lt != rt: - left_on_id = left_on[i] - idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) - key = left.data_column_pandas_labels[idx] - lt = lt if lt is not None else left.get_snowflake_type(left_on_id) - rt = rt if rt is not None else right.get_snowflake_type(right_on[i]) - raise ValueError( - f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. 
" - f"If you wish to proceed you should use pd.concat" - ) + left_on = left_on or [] + right_on = right_on or [] + assert len(left_on) == len( + right_on + ), "left_on and right_on must be of same length or both be None" - assert_snowpark_pandas_types_match() + if how == "asof": + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + left_join_key = [left_match_col] + right_join_key = [right_match_col] + left_join_key.extend(left_on) + right_join_key.extend(right_on) + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key" + else: + left_join_key = left_on + right_join_key = right_on + assert ( + left_match_col is None + and right_match_col is None + and match_comparator is None + ), f"match condition should not be provided for {how} join" + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "join_key_coalesce_config must be of same length as left_on and right_on" + + assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key) # Re-project the active columns to make sure all active columns of the internal frame participate # in the join operation, and unnecessary columns are dropped from the projected columns. @@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None: match_comparator=match_comparator, how=how, ) - return _create_internal_frame_with_join_or_align_result( joined_ordered_dataframe, left, right, how, - left_on, - right_on, + left_join_key, + right_join_key, sort, join_key_coalesce_config, inherit_join_index, @@ -1402,6 +1454,9 @@ def _sort_on_join_keys(self) -> None: ) elif self._how == "right": ordering_column_identifiers = mapped_right_on + elif self._how == "asof": + # Order only by the left match_condition column + ordering_column_identifiers = [mapped_left_on[0]] else: # left join, inner join, left align, coalesce align ordering_column_identifiers = mapped_left_on diff --git a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py index f7ae87c2a5d..91537d98e30 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py @@ -1197,22 +1197,29 @@ def join( # get the new mapped right on identifier right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols] - # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' 
- on = None - for left_col, right_col in zip(left_on_cols, right_on_cols): - eq = Column(left_col).equal_null(Column(right_col)) - on = eq if on is None else on & eq - if how == "asof": - assert left_match_col, "left_match_col was not provided to ASOF Join" + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" left_match_col = Column(left_match_col) # Get the new mapped right match condition identifier - assert right_match_col, "right_match_col was not provided to ASOF Join" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" right_match_col = Column(right_identifiers_rename_map[right_match_col]) # ASOF Join requires the use of match_condition - assert match_comparator, "match_comparator was not provided to ASOF Join" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + + on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).__eq__(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right=right_snowpark_dataframe_ref.snowpark_dataframe, + on=on, how=how, match_condition=getattr(left_match_col, match_comparator.value)( right_match_col @@ -1224,6 +1231,12 @@ def join( right_snowpark_dataframe_ref.snowpark_dataframe, how=how ) else: + # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' + on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).equal_null(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right_snowpark_dataframe_ref.snowpark_dataframe, on, how ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py index 3bf1062107e..e7a96b49ef1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py @@ -520,12 +520,15 @@ def single_pivot_helper( data_column_snowflake_quoted_identifiers: new data column snowflake quoted identifiers this pivot result data_column_pandas_labels: new data column pandas labels for this pivot result """ - snowpark_aggr_func = get_snowflake_agg_func(pandas_aggr_func_name, {}) - if not is_supported_snowflake_pivot_agg_func(snowpark_aggr_func): + snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0) + if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func( + snowflake_agg_func.snowpark_aggregation + ): # TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations raise ErrorMessage.not_implemented( f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments." ) + snowpark_aggr_func = snowflake_agg_func.snowpark_aggregation pandas_aggr_label, aggr_snowflake_quoted_identifier = value_label_to_identifier_pair @@ -1231,17 +1234,19 @@ def get_margin_aggregation( Returns: Snowpark column expression for the aggregation function result. """ - resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}) + resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}, axis=0) # This would have been resolved during the original pivot at an early stage. 
assert resolved_aggfunc is not None, "resolved_aggfunc is None" - aggfunc_expr = resolved_aggfunc(snowflake_quoted_identifier) + aggregation_expression = resolved_aggfunc.snowpark_aggregation( + snowflake_quoted_identifier + ) - if resolved_aggfunc == sum_: - aggfunc_expr = coalesce(aggfunc_expr, pandas_lit(0)) + if resolved_aggfunc.snowpark_aggregation == sum_: + aggregation_expression = coalesce(aggregation_expression, pandas_lit(0)) - return aggfunc_expr + return aggregation_expression def expand_pivot_result_with_pivot_table_margins_no_groupby_columns( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py index de83e0429bf..ba8ceedec5e 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py @@ -649,8 +649,6 @@ def perform_asof_join_on_frame( left=preserving_frame, right=referenced_frame, how="asof", - left_on=[], - right_on=[], left_match_col=left_timecol_snowflake_quoted_identifier, right_match_col=right_timecol_snowflake_quoted_identifier, match_comparator=( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py index d38584c14de..e19a6de37ba 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py @@ -567,9 +567,7 @@ def __new__( attrs (Dict[str, Any]): The attributes of the class. Returns: - Union[snowflake.snowpark.modin.pandas.series.Series, - snowflake.snowpark.modin.pandas.dataframe.DataFrame, - snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy, + Union[snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy, snowflake.snowpark.modin.pandas.resample.Resampler, snowflake.snowpark.modin.pandas.window.Window, snowflake.snowpark.modin.pandas.window.Rolling]: diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py index f8629e664f3..3b714087535 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py @@ -525,7 +525,7 @@ def tz_convert_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column: The column after conversion to the specified timezone """ if tz is None: - return convert_timezone(pandas_lit("UTC"), column) + return to_timestamp_ntz(convert_timezone(pandas_lit("UTC"), column)) else: if isinstance(tz, dt.tzinfo): tz_name = tz.tzname(None) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py index 9f01954ab2c..34a3376fcc1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py @@ -1276,7 +1276,7 @@ def check_snowpark_pandas_object_in_arg(arg: Any) -> bool: if check_snowpark_pandas_object_in_arg(v): return True else: - from snowflake.snowpark.modin.pandas import DataFrame, Series + from modin.pandas import DataFrame, Series return isinstance(arg, (DataFrame, Series)) @@ -1519,9 +1519,15 @@ def convert_str_to_timedelta(x: str) -> pd.Timedelta: downcast_pandas_df.columns, cached_snowpark_pandas_types ): if snowpark_pandas_type is not None and snowpark_pandas_type == timedelta_t: - downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply( - convert_str_to_timedelta - ) + # By default, pandas warns, "A value is 
trying to be set on a + # copy of a slice from a DataFrame" here because we are + # assigning a column to downcast_pandas_df, which is a copy of + # a slice of pandas_df. We don't care what happens to pandas_df, + # so the warning isn't useful to us. + with native_pd.option_context("mode.chained_assignment", None): + downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply( + convert_str_to_timedelta + ) # Step 7. postprocessing for return types if index_only: diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index b5022bff46b..a3981379aaf 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -149,8 +149,6 @@ ) from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( AGG_NAME_COL_LABEL, - GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE, - GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES, AggFuncInfo, AggFuncWithLabel, AggregateColumnOpParameters, @@ -161,7 +159,6 @@ convert_agg_func_arg_to_col_agg_func_map, drop_non_numeric_data_columns, generate_column_agg_info, - generate_rowwise_aggregation_function, get_agg_func_to_col_map, get_pandas_aggr_func_name, get_snowflake_agg_func, @@ -2040,8 +2037,8 @@ def binary_op( # Native pandas does not support binary operations between a Series and a list-like object. from modin.pandas import Series + from modin.pandas.dataframe import DataFrame - from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import is_scalar # fail explicitly for unsupported scenarios @@ -3556,42 +3553,22 @@ def convert_func_to_agg_func_info( agg_col_ops, new_data_column_index_names = generate_column_agg_info( internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby ) - # Get the column aggregation functions used to check if the function - # preserves Snowpark pandas types. - agg_col_funcs = [] - for _, func in column_to_agg_func.items(): - if is_list_like(func) and not is_named_tuple(func): - for fn in func: - agg_col_funcs.append(fn.func) - else: - agg_col_funcs.append(func.func) # the pandas label and quoted identifier generated for each result column # after aggregation will be used as new pandas label and quoted identifiers. new_data_column_pandas_labels = [] new_data_column_quoted_identifiers = [] new_data_column_snowpark_pandas_types = [] - for i in range(len(agg_col_ops)): - col_agg_op = agg_col_ops[i] - col_agg_func = agg_col_funcs[i] - new_data_column_pandas_labels.append(col_agg_op.agg_pandas_label) + for agg_col_op in agg_col_ops: + new_data_column_pandas_labels.append(agg_col_op.agg_pandas_label) new_data_column_quoted_identifiers.append( - col_agg_op.agg_snowflake_quoted_identifier + agg_col_op.agg_snowflake_quoted_identifier + ) + new_data_column_snowpark_pandas_types.append( + agg_col_op.data_type + if isinstance(agg_col_op.data_type, SnowparkPandasType) + and agg_col_op.snowflake_agg_func.preserves_snowpark_pandas_types + else None ) - if col_agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE: - new_data_column_snowpark_pandas_types.append( - col_agg_op.data_type - if isinstance(col_agg_op.data_type, SnowparkPandasType) - else None - ) - elif col_agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES: - # In the case where the aggregation overrides the type of the output data column - # (e.g. 
any always returns boolean data columns), set the output Snowpark pandas type - # of the given column to None - new_data_column_snowpark_pandas_types.append(None) # type: ignore - else: - self._raise_not_implemented_error_for_timedelta() - new_data_column_snowpark_pandas_types = None # type: ignore - # The ordering of the named aggregations is changed by us when we process # the agg_kwargs into the func dict (named aggregations on the same # column are moved to be contiguous, see groupby.py::aggregate for an @@ -3644,7 +3621,7 @@ def convert_func_to_agg_func_info( ), agg_pandas_label=None, agg_snowflake_quoted_identifier=row_position_quoted_identifier, - snowflake_agg_func=min_, + snowflake_agg_func=get_snowflake_agg_func("min", agg_kwargs={}, axis=0), ordering_columns=internal_frame.ordering_columns, ) agg_col_ops.append(row_position_agg_column_op) @@ -5657,8 +5634,6 @@ def agg( args: the arguments passed for the aggregation kwargs: keyword arguments passed for the aggregation function. """ - self._raise_not_implemented_error_for_timedelta() - numeric_only = kwargs.get("numeric_only", False) # Call fallback if the aggregation function passed in the arg is currently not supported # by snowflake engine. @@ -5704,6 +5679,11 @@ def agg( not is_list_like(value) for value in func.values() ) if axis == 1: + if any( + isinstance(t, TimedeltaType) + for t in internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values() + ): + ErrorMessage.not_implemented_for_timedelta("agg(axis=1)") if self.is_multiindex(): # TODO SNOW-1010307 fix axis=1 behavior with MultiIndex ErrorMessage.not_implemented( @@ -5761,9 +5741,9 @@ def agg( pandas_column_labels=frame.data_column_pandas_labels, ) if agg_arg in ("idxmin", "idxmax") - else generate_rowwise_aggregation_function(agg_arg, kwargs)( - *(col(c) for c in data_col_identifiers) - ) + else get_snowflake_agg_func( + agg_arg, kwargs, axis=1 + ).snowpark_aggregation(*(col(c) for c in data_col_identifiers)) for agg_arg in agg_args } pandas_labels = list(agg_col_map.keys()) @@ -5883,7 +5863,13 @@ def generate_agg_qc( index_column_snowflake_quoted_identifiers=[ agg_name_col_quoted_identifier ], - data_column_types=None, + data_column_types=[ + col.data_type + if isinstance(col.data_type, SnowparkPandasType) + and col.snowflake_agg_func.preserves_snowpark_pandas_types + else None + for col in col_agg_infos + ], index_column_types=None, ) return SnowflakeQueryCompiler(single_agg_dataframe) @@ -7395,18 +7381,9 @@ def merge_asof( SnowflakeQueryCompiler """ # TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation - if ( - by - or left_by - or right_by - or left_index - or right_index - or tolerance - or suffixes != ("_x", "_y") - ): + if tolerance or suffixes != ("_x", "_y"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method does not currently support parameters " - + "'by', 'left_by', 'right_by', 'left_index', 'right_index', " + "'suffixes', or 'tolerance'" ) if direction not in ("backward", "forward"): @@ -7414,9 +7391,24 @@ def merge_asof( "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'" ) + if direction == "backward": + match_comparator = ( + MatchComparator.GREATER_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.GREATER_THAN + ) + else: + match_comparator = ( + MatchComparator.LESS_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.LESS_THAN + ) + left_frame = self._modin_frame right_frame = right._modin_frame - left_keys, right_keys = 
join_utils.get_join_keys( + # Get the left and right matching key and quoted identifier corresponding to the match_condition + # There will only be matching key/identifier for each table as there is only a single match condition + left_match_keys, right_match_keys = join_utils.get_join_keys( left=left_frame, right=right_frame, on=on, @@ -7425,42 +7417,62 @@ def merge_asof( left_index=left_index, right_index=right_index, ) - left_match_col = ( + left_match_identifier = ( left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - left_keys + left_match_keys )[0][0] ) - right_match_col = ( + right_match_identifier = ( right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - right_keys + right_match_keys )[0][0] ) - - if direction == "backward": - match_comparator = ( - MatchComparator.GREATER_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.GREATER_THAN + coalesce_config = join_utils.get_coalesce_config( + left_keys=left_match_keys, + right_keys=right_match_keys, + external_join_keys=[], + ) + + # Get the left and right matching keys and quoted identifiers corresponding to the 'on' condition + if by or (left_by and right_by): + left_on_keys, right_on_keys = join_utils.get_join_keys( + left=left_frame, + right=right_frame, + on=by, + left_on=left_by, + right_on=right_by, + ) + left_on_identifiers = [ + ids[0] + for ids in left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + left_on_keys + ) + ] + right_on_identifiers = [ + ids[0] + for ids in right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + right_on_keys + ) + ] + coalesce_config.extend( + join_utils.get_coalesce_config( + left_keys=left_on_keys, + right_keys=right_on_keys, + external_join_keys=[], + ) ) else: - match_comparator = ( - MatchComparator.LESS_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.LESS_THAN - ) - - coalesce_config = join_utils.get_coalesce_config( - left_keys=left_keys, right_keys=right_keys, external_join_keys=[] - ) + left_on_identifiers = [] + right_on_identifiers = [] joined_frame, _ = join_utils.join( left=left_frame, right=right_frame, + left_on=left_on_identifiers, + right_on=right_on_identifiers, how="asof", - left_on=[left_match_col], - right_on=[right_match_col], - left_match_col=left_match_col, - right_match_col=right_match_col, + left_match_col=left_match_identifier, + right_match_col=right_match_identifier, match_comparator=match_comparator, join_key_coalesce_config=coalesce_config, sort=True, @@ -9129,7 +9141,9 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler": SnowflakeQueryCompiler Transposed new QueryCompiler object. """ - self._raise_not_implemented_error_for_timedelta() + if len(set(self._modin_frame.cached_data_column_snowpark_pandas_types)) > 1: + # In this case, transpose may lose types. + self._raise_not_implemented_error_for_timedelta() frame = self._modin_frame @@ -12513,8 +12527,6 @@ def _quantiles_single_col( column would allow us to create an accurate row position column, but would require a potentially expensive JOIN operator afterwards to apply the correct index labels. 
""" - self._raise_not_implemented_error_for_timedelta() - assert len(self._modin_frame.data_column_pandas_labels) == 1 if index is not None: @@ -12579,7 +12591,7 @@ def _quantiles_single_col( ], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=[index_identifier], - data_column_types=None, + data_column_types=original_frame.cached_data_column_snowpark_pandas_types, index_column_types=None, ) # We cannot call astype() directly to convert an index column, so we replicate @@ -13613,6 +13625,16 @@ def _window_agg( } ).frame else: + snowflake_agg_func = get_snowflake_agg_func(agg_func, agg_kwargs, axis=0) + if snowflake_agg_func is None: + # We don't have test coverage for this situation because we + # test individual rolling and expanding methods we've implemented, + # like rolling_sum(), but other rolling methods raise + # NotImplementedError immediately. We also don't support rolling + # agg(), which might take us here. + ErrorMessage.not_implemented( # pragma: no cover + f"Window aggregation does not support the aggregation {repr_aggregate_function(agg_func, agg_kwargs)}" + ) new_frame = frame.update_snowflake_quoted_identifiers_with_expressions( { # If aggregation is count use count on row_position_quoted_identifier @@ -13623,7 +13645,7 @@ def _window_agg( if agg_func == "count" else count(col(quoted_identifier)).over(window_expr) >= min_periods, - get_snowflake_agg_func(agg_func, agg_kwargs)( + snowflake_agg_func.snowpark_aggregation( # Expanding is cumulative so replace NULL with 0 for sum aggregation builtin("zeroifnull")(col(quoted_identifier)) if window_func == WindowFunction.EXPANDING @@ -14577,8 +14599,6 @@ def idxmax( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmax", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -14603,8 +14623,6 @@ def idxmin( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmin", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -16668,6 +16686,7 @@ def dt_tz_localize( tz: Union[str, tzinfo], ambiguous: str = "raise", nonexistent: str = "raise", + include_index: bool = False, ) -> "SnowflakeQueryCompiler": """ Localize tz-naive to tz-aware. @@ -16675,39 +16694,50 @@ def dt_tz_localize( tz : str, pytz.timezone, optional ambiguous : {"raise", "inner", "NaT"} or bool mask, default: "raise" nonexistent : {"raise", "shift_forward", "shift_backward, "NaT"} or pandas.timedelta, default: "raise" + include_index: Whether to include the index columns in the operation. Returns: BaseQueryCompiler New QueryCompiler containing values with localized time zone. 
""" + dtype = self.index_dtypes[0] if include_index else self.dtypes[0] + if not include_index: + method_name = "Series.dt.tz_localize" + else: + assert is_datetime64_any_dtype(dtype), "column must be datetime" + method_name = "DatetimeIndex.tz_localize" + if not isinstance(ambiguous, str) or ambiguous != "raise": - ErrorMessage.parameter_not_implemented_error( - "ambiguous", "Series.dt.tz_localize" - ) + ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) if not isinstance(nonexistent, str) or nonexistent != "raise": - ErrorMessage.parameter_not_implemented_error( - "nonexistent", "Series.dt.tz_localize" - ) + ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - lambda column: tz_localize_column(column, tz) + lambda column: tz_localize_column(column, tz), + include_index, ) ) - def dt_tz_convert(self, tz: Union[str, tzinfo]) -> "SnowflakeQueryCompiler": + def dt_tz_convert( + self, + tz: Union[str, tzinfo], + include_index: bool = False, + ) -> "SnowflakeQueryCompiler": """ Convert time-series data to the specified time zone. Args: tz : str, pytz.timezone + include_index: Whether to include the index columns in the operation. Returns: A new QueryCompiler containing values with converted time zone. """ return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - lambda column: tz_convert_column(column, tz) + lambda column: tz_convert_column(column, tz), + include_index, ) ) diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/base.py b/src/snowflake/snowpark/modin/plugin/docstrings/base.py index af50e0379dd..4044f7b675f 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/base.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/base.py @@ -2832,6 +2832,7 @@ def shift(): """ Implement shared functionality between DataFrame and Series for shift. axis argument is only relevant for Dataframe, and should be 0 for Series. + Args: periods : int | Sequence[int] Number of periods to shift. Can be positive or negative. If an iterable of ints, diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py index 6223e9dd273..f7e93e6c2df 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py @@ -1749,7 +1749,7 @@ def info(): ... 'COL2': ['A', 'B', 'C']}) >>> df.info() # doctest: +NORMALIZE_WHITESPACE - + SnowflakeIndex Data columns (total 2 columns): # Column Non-Null Count Dtype diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index aeca9d6e305..ecef6e843ba 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -60,7 +60,6 @@ validate_percentile, ) -import snowflake.snowpark.modin.pandas as spd from snowflake.snowpark.modin.pandas.api.extensions import ( register_dataframe_accessor, register_series_accessor, @@ -88,8 +87,6 @@ def register_base_override(method_name: str): for directly overriding methods on BasePandasDataset, we mock this by performing the override on DataFrame and Series, and manually performing a `setattr` on the base class. These steps are necessary to allow both the docstring extension and method dispatch to work properly. 
- - Methods annotated here also are automatically instrumented with Snowpark pandas telemetry. """ def decorator(base_method: Any): @@ -103,10 +100,7 @@ def decorator(base_method: Any): series_method = series_method.fget if series_method is None or series_method is parent_method: register_series_accessor(method_name)(base_method) - # TODO: SNOW-1063346 - # Since we still use the vendored version of DataFrame and the overrides for the top-level - # namespace haven't been performed yet, we need to set properties on the vendored version - df_method = getattr(spd.dataframe.DataFrame, method_name, None) + df_method = getattr(pd.DataFrame, method_name, None) if isinstance(df_method, property): df_method = df_method.fget if df_method is None or df_method is parent_method: @@ -176,6 +170,22 @@ def filter( pass # pragma: no cover +@register_base_not_implemented() +def interpolate( + self, + method="linear", + *, + axis=0, + limit=None, + inplace=False, + limit_direction: str | None = None, + limit_area=None, + downcast=lib.no_default, + **kwargs, +): # noqa: PR01, RT01, D200 + pass + + @register_base_not_implemented() def pipe(self, func, *args, **kwargs): # noqa: PR01, RT01, D200 pass # pragma: no cover @@ -813,7 +823,7 @@ def _binary_op( **kwargs, ) - from snowflake.snowpark.modin.pandas.dataframe import DataFrame + from modin.pandas.dataframe import DataFrame # Modin Bug: https://github.com/modin-project/modin/issues/7236 # For a Series interacting with a DataFrame, always return a DataFrame diff --git a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py index 5ce836061ab..62c9cab4dc1 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py @@ -7,20 +7,1443 @@ pandas, such as `DataFrame.memory_usage`. 
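Each override below is installed with `register_dataframe_accessor` (or the
`register_dataframe_not_implemented` helper defined in this file), so importing
this module replaces the corresponding Modin DataFrame methods.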
""" -from typing import Any, Union +from __future__ import annotations +import collections +import datetime +import functools +import itertools +import sys +import warnings +from typing import ( + IO, + Any, + Callable, + Hashable, + Iterable, + Iterator, + Literal, + Mapping, + Sequence, +) + +import modin.pandas as pd +import numpy as np import pandas as native_pd -from modin.pandas import DataFrame -from pandas._typing import Axis, PythonFuncType -from pandas.core.dtypes.common import is_dict_like, is_list_like +from modin.pandas import DataFrame, Series +from modin.pandas.base import BasePandasDataset +from pandas._libs.lib import NoDefault, no_default +from pandas._typing import ( + AggFuncType, + AnyArrayLike, + Axes, + Axis, + CompressionOptions, + FilePath, + FillnaOptions, + IgnoreRaise, + IndexLabel, + Level, + PythonFuncType, + Renamer, + Scalar, + StorageOptions, + Suffixes, + WriteBuffer, +) +from pandas.core.common import apply_if_callable, is_bool_indexer +from pandas.core.dtypes.common import ( + infer_dtype_from_object, + is_bool_dtype, + is_dict_like, + is_list_like, + is_numeric_dtype, +) +from pandas.core.dtypes.inference import is_hashable, is_integer +from pandas.core.indexes.frozen import FrozenList +from pandas.io.formats.printing import pprint_thing +from pandas.util._validators import validate_bool_kwarg + +from snowflake.snowpark.modin.pandas.api.extensions import register_dataframe_accessor +from snowflake.snowpark.modin.pandas.groupby import ( + DataFrameGroupBy, + validate_groupby_args, +) +from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( + SnowparkPandasRowPartitionIterator, +) +from snowflake.snowpark.modin.pandas.utils import ( + create_empty_native_pandas_frame, + from_non_pandas, + from_pandas, + is_scalar, + raise_if_native_pandas_objects, + replace_external_data_keys_with_empty_pandas_series, + replace_external_data_keys_with_query_compiler, +) +from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + is_snowflake_agg_func, +) +from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated +from snowflake.snowpark.modin.plugin._typing import ListLike +from snowflake.snowpark.modin.plugin.utils.error_message import ( + ErrorMessage, + dataframe_not_implemented, +) +from snowflake.snowpark.modin.plugin.utils.frontend_constants import ( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE, + DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE, + DF_SETITEM_SLICE_AS_SCALAR_VALUE, +) +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import ( + _inherit_docstrings, + hashable, + validate_int_kwarg, +) +from snowflake.snowpark.udf import UserDefinedFunction + + +def register_dataframe_not_implemented(): + def decorator(base_method: Any): + func = dataframe_not_implemented()(base_method) + register_dataframe_accessor(base_method.__name__)(func) + return func + + return decorator + + +# === UNIMPLEMENTED METHODS === +# The following methods are not implemented in Snowpark pandas, and must be overridden on the +# frontend. These methods fall into a few categories: +# 1. Would work in Snowpark pandas, but we have not tested it. +# 2. Would work in Snowpark pandas, but requires more SQL queries than we are comfortable with. +# 3. Requires materialization (usually via a frontend _default_to_pandas call). +# 4. Performs operations on a native pandas Index object that are nontrivial for Snowpark pandas to manage. 
+ + +# Avoid overwriting builtin `map` by accident +@register_dataframe_accessor("map") +@dataframe_not_implemented() +def _map(self, func, na_action: str | None = None, **kwargs) -> DataFrame: + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def boxplot( + self, + column=None, + by=None, + ax=None, + fontsize=None, + rot=0, + grid=True, + figsize=None, + layout=None, + return_type=None, + backend=None, + **kwargs, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def combine( + self, other, func, fill_value=None, overwrite=True +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def corrwith( + self, other, axis=0, drop=False, method="pearson", numeric_only=False +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def cov( + self, min_periods=None, ddof: int | None = 1, numeric_only=False +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def dot(self, other): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def hist( + self, + column=None, + by=None, + grid=True, + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, + ax=None, + sharex=False, + sharey=False, + figsize=None, + layout=None, + bins=10, + **kwds, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def isetitem(self, loc, value): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def prod( + self, + axis=None, + skipna=True, + numeric_only=False, + min_count=0, + **kwargs, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +register_dataframe_accessor("product")(prod) + + +@register_dataframe_not_implemented() +def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def reindex_like( + self, + other, + method=None, + copy: bool | None = None, + limit=None, + tolerance=None, +) -> DataFrame: # pragma: no cover + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_gbq( + self, + destination_table, + project_id=None, + chunksize=None, + reauth=False, + if_exists="fail", + auth_local_webserver=True, + table_schema=None, + location=None, + progress_bar=True, + credentials=None, +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_html( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep="NaN", + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + justify=None, + max_rows=None, + max_cols=None, + show_dimensions=False, + decimal=".", + bold_rows=True, + classes=None, + escape=True, + notebook=False, + border=None, + table_id=None, + render_links=False, + encoding=None, +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_parquet( + self, + path=None, + engine="auto", + 
compression="snappy", + index=None, + partition_cols=None, + storage_options: StorageOptions = None, + **kwargs, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_period( + self, freq=None, axis=0, copy=True +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_records( + self, index=True, column_dtypes=None, index_dtypes=None +): # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_stata( + self, + path: FilePath | WriteBuffer[bytes], + convert_dates: dict[Hashable, str] | None = None, + write_index: bool = True, + byteorder: str | None = None, + time_stamp: datetime.datetime | None = None, + data_label: str | None = None, + variable_labels: dict[Hashable, str] | None = None, + version: int | None = 114, + convert_strl: Sequence[Hashable] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions = None, + *, + value_labels: dict[Hashable, dict[float | int, str]] | None = None, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def to_xml( + self, + path_or_buffer=None, + index=True, + root_name="data", + row_name="row", + na_rep=None, + attr_cols=None, + elem_cols=None, + namespaces=None, + prefix=None, + encoding="utf-8", + xml_declaration=True, + pretty_print=True, + parser="lxml", + stylesheet=None, + compression="infer", + storage_options=None, +): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __delitem__(self, key): + pass # pragma: no cover + + +@register_dataframe_accessor("attrs") +@dataframe_not_implemented() +@property +def attrs(self): # noqa: RT01, D200 + pass # pragma: no cover + + +@register_dataframe_accessor("style") +@dataframe_not_implemented() +@property +def style(self): # noqa: RT01, D200 + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __reduce__(self): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __divmod__(self, other): + pass # pragma: no cover + + +@register_dataframe_not_implemented() +def __rdivmod__(self, other): + pass # pragma: no cover + + +# The from_dict and from_records accessors are class methods and cannot be overridden via the +# extensions module, as they need to be foisted onto the namespace directly because they are not +# routed through getattr. To this end, we manually set DataFrame.from_dict to our new method. +@dataframe_not_implemented() +def from_dict( + cls, data, orient="columns", dtype=None, columns=None +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +DataFrame.from_dict = from_dict + + +@dataframe_not_implemented() +def from_records( + cls, + data, + index=None, + exclude=None, + columns=None, + coerce_float=False, + nrows=None, +): # pragma: no cover # noqa: PR01, RT01, D200 + pass # pragma: no cover + + +DataFrame.from_records = from_records + + +# === OVERRIDDEN METHODS === +# The below methods have their frontend implementations overridden compared to the version present +# in series.py. This is usually for one of the following reasons: +# 1. The underlying QC interface used differs from that of modin. Notably, this applies to aggregate +# and binary operations; further work is needed to refactor either our implementation or upstream +# modin's implementation. +# 2. Modin performs extra validation queries that perform extra SQL queries. 
Some of these are already +# fixed on main; see https://github.com/modin-project/modin/issues/7340 for details. +# 3. Upstream Modin defaults to pandas for some edge cases. Defaulting to pandas at the query compiler +# layer is acceptable because we can force the method to raise NotImplementedError, but if a method +# defaults at the frontend, Modin raises a warning and performs the operation by coercing the +# dataset to a native pandas object. Removing these is tracked by +# https://github.com/modin-project/modin/issues/7104 + + +# Snowpark pandas overrides the constructor for two reasons: +# 1. To support the Snowpark pandas lazy index object +# 2. To avoid raising "UserWarning: Distributing object. This may take some time." +# when a literal is passed in as data. +@register_dataframe_accessor("__init__") +def __init__( + self, + data=None, + index=None, + columns=None, + dtype=None, + copy=None, + query_compiler=None, +) -> None: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # Siblings are other dataframes that share the same query compiler. We + # use this list to update inplace when there is a shallow copy. + from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native + + self._siblings = [] + + # Engine.subscribe(_update_engine) + if isinstance(data, (DataFrame, Series)): + self._query_compiler = data._query_compiler.copy() + if index is not None and any(i not in data.index for i in index): + ErrorMessage.not_implemented( + "Passing non-existant columns or index values to constructor not" + + " yet implemented." + ) # pragma: no cover + if isinstance(data, Series): + # We set the column name if it is not in the provided Series + if data.name is None: + self.columns = [0] if columns is None else columns + # If the columns provided are not in the named Series, pandas clears + # the DataFrame and sets columns to the columns provided. + elif columns is not None and data.name not in columns: + self._query_compiler = from_pandas( + self.__constructor__(columns=columns) + )._query_compiler + if index is not None: + self._query_compiler = data.loc[index]._query_compiler + elif columns is None and index is None: + data._add_sibling(self) + else: + if columns is not None and any(i not in data.columns for i in columns): + ErrorMessage.not_implemented( + "Passing non-existant columns or index values to constructor not" + + " yet implemented." 
+ ) # pragma: no cover + if index is None: + index = slice(None) + if columns is None: + columns = slice(None) + self._query_compiler = data.loc[index, columns]._query_compiler + + # Check type of data and use appropriate constructor + elif query_compiler is None: + distributed_frame = from_non_pandas(data, index, columns, dtype) + if distributed_frame is not None: + self._query_compiler = distributed_frame._query_compiler + return + + if isinstance(data, native_pd.Index): + pass + elif is_list_like(data) and not is_dict_like(data): + old_dtype = getattr(data, "dtype", None) + values = [ + obj._to_pandas() if isinstance(obj, Series) else obj for obj in data + ] + if isinstance(data, np.ndarray): + data = np.array(values, dtype=old_dtype) + else: + try: + data = type(data)(values, dtype=old_dtype) + except TypeError: + data = values + elif is_dict_like(data) and not isinstance( + data, (native_pd.Series, Series, native_pd.DataFrame, DataFrame) + ): + if columns is not None: + data = {key: value for key, value in data.items() if key in columns} + + if len(data) and all(isinstance(v, Series) for v in data.values()): + from modin.pandas import concat + + new_qc = concat(data.values(), axis=1, keys=data.keys())._query_compiler + + if dtype is not None: + new_qc = new_qc.astype({col: dtype for col in new_qc.columns}) + if index is not None: + new_qc = new_qc.reindex( + axis=0, labels=try_convert_index_to_native(index) + ) + if columns is not None: + new_qc = new_qc.reindex( + axis=1, labels=try_convert_index_to_native(columns) + ) + + self._query_compiler = new_qc + return + + data = { + k: v._to_pandas() if isinstance(v, Series) else v + for k, v in data.items() + } + pandas_df = native_pd.DataFrame( + data=try_convert_index_to_native(data), + index=try_convert_index_to_native(index), + columns=try_convert_index_to_native(columns), + dtype=dtype, + copy=copy, + ) + self._query_compiler = from_pandas(pandas_df)._query_compiler + else: + self._query_compiler = query_compiler + + +@register_dataframe_accessor("__dataframe__") +def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Get a Modin DataFrame that implements the dataframe exchange protocol. + + See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. + + Parameters + ---------- + nan_as_null : bool, default: False + A keyword intended for the consumer to tell the producer + to overwrite null values in the data with ``NaN`` (or ``NaT``). + This currently has no effect; once support for nullable extension + dtypes is added, this value should be propagated to columns. + allow_copy : bool, default: True + A keyword that defines whether or not the library is allowed + to make a copy of the data. For example, copying data would be necessary + if a library supports strided buffers, given that this protocol + specifies contiguous buffers. Currently, if the flag is set to ``False`` + and a copy is needed, a ``RuntimeError`` will be raised. + + Returns + ------- + ProtocolDataframe + A dataframe object following the dataframe protocol specification. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented( + "Snowpark pandas does not support the DataFrame interchange " + + "protocol method `__dataframe__`. To use Snowpark pandas " + + "DataFrames with third-party libraries that try to call the " + + "`__dataframe__` method, please convert this Snowpark pandas " + + "DataFrame to pandas with `to_pandas()`." 
+ ) + + return self._query_compiler.to_dataframe( + nan_as_null=nan_as_null, allow_copy=allow_copy + ) + + +# Snowpark pandas defaults to axis=1 instead of axis=0 for these; we need to investigate if the same should +# apply to upstream Modin. +@register_dataframe_accessor("__and__") +def __and__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__and__", other, axis=1) + + +@register_dataframe_accessor("__rand__") +def __rand__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__rand__", other, axis=1) + + +@register_dataframe_accessor("__or__") +def __or__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__or__", other, axis=1) + + +@register_dataframe_accessor("__ror__") +def __ror__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__ror__", other, axis=1) + + +# Upstream Modin defaults to pandas in some cases. +@register_dataframe_accessor("apply") +def apply( + self, + func: AggFuncType | UserDefinedFunction, + axis: Axis = 0, + raw: bool = False, + result_type: Literal["expand", "reduce", "broadcast"] | None = None, + args=(), + **kwargs, +): + """ + Apply a function along an axis of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) + query_compiler = self._query_compiler.apply( + func, + axis, + raw=raw, + result_type=result_type, + args=args, + **kwargs, + ) + if not isinstance(query_compiler, type(self._query_compiler)): + # A scalar was returned + return query_compiler + + # If True, it is an unamed series. + # Theoretically, if df.apply returns a Series, it will only be an unnamed series + # because the function is supposed to be series -> scalar. + if query_compiler._modin_frame.is_unnamed_series(): + return Series(query_compiler=query_compiler) + else: + return self.__constructor__(query_compiler=query_compiler) + + +# Snowpark pandas uses a separate QC method, while modin directly calls map. +@register_dataframe_accessor("applymap") +def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not callable(func): + raise TypeError(f"{func} is not callable") + return self.__constructor__( + query_compiler=self._query_compiler.applymap( + func, na_action=na_action, **kwargs + ) + ) + + +# We need to override _get_columns to satisfy +# tests/unit/modin/test_type_annotations.py::test_properties_snow_1374293[_get_columns-type_hints1] +# since Modin doesn't provide this type hint. +def _get_columns(self) -> native_pd.Index: + """ + Get the columns for this Snowpark pandas ``DataFrame``. + + Returns + ------- + Index + The all columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.columns + + +# Snowpark pandas wraps this in an update_in_place +def _set_columns(self, new_columns: Axes) -> None: + """ + Set the columns for this Snowpark pandas ``DataFrame``. + + Parameters + ---------- + new_columns : + The new columns to set. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + self._update_inplace( + new_query_compiler=self._query_compiler.set_columns(new_columns) + ) + + +register_dataframe_accessor("columns")(property(_get_columns, _set_columns)) + + +# Snowpark pandas does preprocessing for numeric_only (should be pushed to QC). +@register_dataframe_accessor("corr") +def corr( + self, + method: str | Callable = "pearson", + min_periods: int | None = None, + numeric_only: bool = False, +): # noqa: PR01, RT01, D200 + """ + Compute pairwise correlation of columns, excluding NA/null values. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + corr_df = self + if numeric_only: + corr_df = self.drop( + columns=[ + i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) + ] + ) + return self.__constructor__( + query_compiler=corr_df._query_compiler.corr( + method=method, + min_periods=min_periods, + ) + ) + + +# Snowpark pandas does not respect `ignore_index`, and upstream Modin does not respect `how`. +@register_dataframe_accessor("dropna") +def dropna( + self, + *, + axis: Axis = 0, + how: str | NoDefault = no_default, + thresh: int | NoDefault = no_default, + subset: IndexLabel = None, + inplace: bool = False, +): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super(DataFrame, self)._dropna( + axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace + ) + + +# Snowpark pandas uses `self_is_series`, while upstream Modin uses `squeeze_self` and `squeeze_value`. +@register_dataframe_accessor("fillna") +def fillna( + self, + value: Hashable | Mapping | Series | DataFrame = None, + *, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, +) -> DataFrame | None: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super(DataFrame, self).fillna( + self_is_series=False, + value=value, + method=method, + axis=axis, + inplace=inplace, + limit=limit, + downcast=downcast, + ) + + +# Snowpark pandas does different validation and returns a custom GroupBy object. +@register_dataframe_accessor("groupby") +def groupby( + self, + by=None, + axis: Axis | NoDefault = no_default, + level: IndexLabel | None = None, + as_index: bool = True, + sort: bool = True, + group_keys: bool = True, + observed: bool | NoDefault = no_default, + dropna: bool = True, +): + """ + Group ``DataFrame`` using a mapper or by a ``Series`` of columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if axis is not no_default: + axis = self._get_axis_number(axis) + if axis == 1: + warnings.warn( + "DataFrame.groupby with axis=1 is deprecated. 
Do " + + "`frame.T.groupby(...)` without axis instead.", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + "The 'axis' keyword in DataFrame.groupby is deprecated and " + + "will be removed in a future version.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + validate_groupby_args(by, level, observed) + + axis = self._get_axis_number(axis) + + if axis != 0 and as_index is False: + raise ValueError("as_index=False only valid for axis=0") + + idx_name = None + + if ( + not isinstance(by, Series) + and is_list_like(by) + and len(by) == 1 + # if by is a list-like of (None,), we have to keep it as a list because + # None may be referencing a column or index level whose label is + # `None`, and by=None wold mean that there is no `by` param. + and by[0] is not None + ): + by = by[0] + + if hashable(by) and ( + not callable(by) and not isinstance(by, (native_pd.Grouper, FrozenList)) + ): + idx_name = by + elif isinstance(by, Series): + idx_name = by.name + if by._parent is self: + # if the SnowSeries comes from the current dataframe, + # convert it to labels directly for easy processing + by = by.name + elif is_list_like(by): + if axis == 0 and all( + ( + (hashable(o) and (o in self)) + or isinstance(o, Series) + or (is_list_like(o) and len(o) == len(self.shape[axis])) + ) + for o in by + ): + # plit 'by's into those that belongs to the self (internal_by) + # and those that doesn't (external_by). For SnowSeries that belongs + # to current DataFrame, we convert it to labels for easy process. + internal_by, external_by = [], [] + + for current_by in by: + if hashable(current_by): + internal_by.append(current_by) + elif isinstance(current_by, Series): + if current_by._parent is self: + internal_by.append(current_by.name) + else: + external_by.append(current_by) # pragma: no cover + else: + external_by.append(current_by) + + by = internal_by + external_by + + return DataFrameGroupBy( + self, + by, + axis, + level, + as_index, + sort, + group_keys, + idx_name, + observed=observed, + dropna=dropna, + ) + + +# Upstream Modin uses a proxy DataFrameInfo object +@register_dataframe_accessor("info") +def info( + self, + verbose: bool | None = None, + buf: IO[str] | None = None, + max_cols: int | None = None, + memory_usage: bool | str | None = None, + show_counts: bool | None = None, + null_counts: bool | None = None, +): # noqa: PR01, D200 + """ + Print a concise summary of the ``DataFrame``. 
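+
+    Note: the index is always summarized as ``SnowflakeIndex``, and when
+    non-null counts are shown they are computed eagerly via ``self.count()``.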
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + def put_str(src, output_len=None, spaces=2): + src = str(src) + return src.ljust(output_len if output_len else len(src)) + " " * spaces + + def format_size(num): + for x in ["bytes", "KB", "MB", "GB", "TB"]: + if num < 1024.0: + return f"{num:3.1f} {x}" + num /= 1024.0 + return f"{num:3.1f} PB" + + output = [] + + type_line = str(type(self)) + index_line = "SnowflakeIndex" + columns = self.columns + columns_len = len(columns) + dtypes = self.dtypes + dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}" + + if max_cols is None: + max_cols = 100 + + exceeds_info_cols = columns_len > max_cols + + if buf is None: + buf = sys.stdout + + if null_counts is None: + null_counts = not exceeds_info_cols + + if verbose is None: + verbose = not exceeds_info_cols + + if null_counts and verbose: + # We're gonna take items from `non_null_count` in a loop, which + # works kinda slow with `Modin.Series`, that's why we call `_to_pandas()` here + # that will be faster. + non_null_count = self.count()._to_pandas() + + if memory_usage is None: + memory_usage = True + + def get_header(spaces=2): + output = [] + head_label = " # " + column_label = "Column" + null_label = "Non-Null Count" + dtype_label = "Dtype" + non_null_label = " non-null" + delimiter = "-" + + lengths = {} + lengths["head"] = max(len(head_label), len(pprint_thing(len(columns)))) + lengths["column"] = max( + len(column_label), max(len(pprint_thing(col)) for col in columns) + ) + lengths["dtype"] = len(dtype_label) + dtype_spaces = ( + max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes)) + - lengths["dtype"] + ) + + header = put_str(head_label, lengths["head"]) + put_str( + column_label, lengths["column"] + ) + if null_counts: + lengths["null"] = max( + len(null_label), + max(len(pprint_thing(x)) for x in non_null_count) + len(non_null_label), + ) + header += put_str(null_label, lengths["null"]) + header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces) + + output.append(header) + + delimiters = put_str(delimiter * lengths["head"]) + put_str( + delimiter * lengths["column"] + ) + if null_counts: + delimiters += put_str(delimiter * lengths["null"]) + delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces) + output.append(delimiters) + + return output, lengths + + output.extend([type_line, index_line]) + + def verbose_repr(output): + columns_line = f"Data columns (total {len(columns)} columns):" + header, lengths = get_header() + output.extend([columns_line, *header]) + for i, col in enumerate(columns): + i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]]) + + to_append = put_str(f" {i}", lengths["head"]) + put_str( + col_s, lengths["column"] + ) + if null_counts: + non_null = pprint_thing(non_null_count[col]) + to_append += put_str(f"{non_null} non-null", lengths["null"]) + to_append += put_str(dtype, lengths["dtype"], spaces=0) + output.append(to_append) + + def non_verbose_repr(output): + output.append(columns._summary(name="Columns")) + + if verbose: + verbose_repr(output) + else: + non_verbose_repr(output) + + output.append(dtypes_line) + + if memory_usage: + deep = memory_usage == "deep" + mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum() + mem_line = f"memory usage: {format_size(mem_usage_bytes)}" + + output.append(mem_line) + + output.append("") + buf.write("\n".join(output)) + + +# Snowpark pandas does different 
validation. +@register_dataframe_accessor("insert") +def insert( + self, + loc: int, + column: Hashable, + value: Scalar | AnyArrayLike, + allow_duplicates: bool | NoDefault = no_default, +) -> None: + """ + Insert column into ``DataFrame`` at specified location. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + raise_if_native_pandas_objects(value) + if allow_duplicates is no_default: + allow_duplicates = False + if not allow_duplicates and column in self.columns: + raise ValueError(f"cannot insert {column}, already exists") + + if not isinstance(loc, int): + raise TypeError("loc must be int") + + # If columns labels are multilevel, we implement following behavior (this is + # name native pandas): + # Case 1: if 'column' is tuple it's length must be same as number of levels + # otherwise raise error. + # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in + # empty strings to match the length of column levels in self frame. + if self.columns.nlevels > 1: + if isinstance(column, tuple) and len(column) != self.columns.nlevels: + # same error as native pandas. + raise ValueError("Item must have length equal to number of levels.") + if not isinstance(column, tuple): + # Fill empty strings to match length of levels + suffix = [""] * (self.columns.nlevels - 1) + column = tuple([column] + suffix) + + # Dictionary keys are treated as index column and this should be joined with + # index of target dataframe. This behavior is similar to 'value' being DataFrame + # or Series, so we simply create Series from dict data here. + if isinstance(value, dict): + value = Series(value, name=column) + + if isinstance(value, DataFrame) or ( + isinstance(value, np.ndarray) and len(value.shape) > 1 + ): + # Supported numpy array shapes are + # 1. (N, ) -> Ex. [1, 2, 3] + # 2. (N, 1) -> Ex> [[1], [2], [3]] + if value.shape[1] != 1: + if isinstance(value, DataFrame): + # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin + raise ValueError( + f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead." + ) + else: + raise ValueError( + f"Expected a 1D array, got an array with shape {value.shape}" + ) + # Change numpy array shape from (N, 1) to (N, ) + if isinstance(value, np.ndarray): + value = value.squeeze(axis=1) + + if ( + is_list_like(value) + and not isinstance(value, (Series, DataFrame)) + and len(value) != self.shape[0] + and not 0 == self.shape[0] # dataframe holds no rows + ): + raise ValueError( + "Length of values ({}) does not match length of index ({})".format( + len(value), len(self) + ) + ) + if not -len(self.columns) <= loc <= len(self.columns): + raise IndexError( + f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}" + ) + elif loc < 0: + raise ValueError("unbounded slice") + + join_on_index = False + if isinstance(value, (Series, DataFrame)): + value = value._query_compiler + join_on_index = True + elif is_list_like(value): + value = Series(value, name=column)._query_compiler + + new_query_compiler = self._query_compiler.insert(loc, column, value, join_on_index) + # In pandas, 'insert' operation is always inplace. 
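+    # Illustrative note (not part of the original change): for example,
+    # df.insert(0, "new", [1, 2, 3]) returns None and mutates df itself, which is why the
+    # new query compiler is swapped into `self` below rather than wrapped in a new DataFrame.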
+ self._update_inplace(new_query_compiler=new_query_compiler) + + +# Snowpark pandas does more specialization based on the type of `values` +@register_dataframe_accessor("isin") +def isin( + self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike] +) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(values, dict): + return super(DataFrame, self).isin(values) + elif isinstance(values, Series): + # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. + # if not values.index.is_unique: + # raise ValueError("cannot compute isin with a duplicate axis.") + return self.__constructor__( + query_compiler=self._query_compiler.isin(values._query_compiler) + ) + elif isinstance(values, DataFrame): + # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. + # if not (values.columns.is_unique and values.index.is_unique): + # raise ValueError("cannot compute isin with a duplicate axis.") + return self.__constructor__( + query_compiler=self._query_compiler.isin(values._query_compiler) + ) + else: + if not is_list_like(values): + # throw pandas compatible error + raise TypeError( + "only list-like or dict-like objects are allowed " + f"to be passed to {self.__class__.__name__}.isin(), " + f"you passed a '{type(values).__name__}'" + ) + return super(DataFrame, self).isin(values) + + +# Upstream Modin defaults to pandas for some arguments. +@register_dataframe_accessor("join") +def join( + self, + other: DataFrame | Series | Iterable[DataFrame | Series], + on: IndexLabel | None = None, + how: str = "left", + lsuffix: str = "", + rsuffix: str = "", + sort: bool = False, + validate: str | None = None, +) -> DataFrame: + """ + Join columns of another ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + for o in other if isinstance(other, list) else [other]: + raise_if_native_pandas_objects(o) + + # Similar to native pandas we implement 'join' using 'pd.merge' method. + # Following code is copied from native pandas (with few changes explained below) + # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002 + if isinstance(other, Series): + # Same error as native pandas. + if other.name is None: + raise ValueError("Other Series must have a name") + other = DataFrame(other) + elif is_list_like(other): + if any([isinstance(o, Series) and o.name is None for o in other]): + raise ValueError("Other Series must have a name") + + if isinstance(other, DataFrame): + if how == "cross": + return pd.merge( + self, + other, + how=how, + on=on, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + return pd.merge( + self, + other, + left_on=on, + how=how, + left_index=on is None, + right_index=True, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + else: # List of DataFrame/Series + # Same error as native pandas. + if on is not None: + raise ValueError( + "Joining multiple DataFrames only supported for joining on index" + ) + + # Same error as native pandas. + if rsuffix or lsuffix: + raise ValueError("Suffixes not supported when joining multiple DataFrames") + + # NOTE: These are not the differences between Snowpark pandas API and pandas behavior + # these are differences between native pandas join behavior when join + # frames have unique index or not. + + # In native pandas logic to join multiple DataFrames/Series is data + # dependent. 
Under the hood it will either use 'concat' or 'merge' API + # Case 1. If all objects being joined have unique index use 'concat' (axis=1) + # Case 2. Otherwise use 'merge' API by looping through objects left to right. + # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046 + + # Even though concat (axis=1) and merge are very similar APIs they have + # some differences which leads to inconsistent behavior in native pandas. + # 1. Treatment of un-named Series + # Case #1: Un-named series is allowed in concat API. Objects are joined + # successfully by assigning a number as columns name (see 'concat' API + # documentation for details on treatment of un-named series). + # Case #2: It raises 'ValueError: Other Series must have a name' + + # 2. how='right' + # Case #1: 'concat' API doesn't support right join. It raises + # 'ValueError: Only can inner (intersect) or outer (union) join the other axis' + # Case #2: Merges successfully. + + # 3. Joining frames with duplicate labels but no conflict with other frames + # Example: self = DataFrame(... columns=["A", "B"]) + # other = [DataFrame(... columns=["C", "C"])] + # Case #1: 'ValueError: Indexes have overlapping values' + # Case #2: Merged successfully. -from snowflake.snowpark.modin.pandas.api.extensions import register_dataframe_accessor -from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( - is_snowflake_agg_func, -) -from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage -from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage -from snowflake.snowpark.modin.utils import _inherit_docstrings, validate_int_kwarg + # In addition to this, native pandas implementation also leads to another + # type of inconsistency where left.join(other, ...) and + # left.join([other], ...) might behave differently for cases mentioned + # above. + # Example: + # import pandas as pd + # df = pd.DataFrame({"a": [4, 5]}) + # other = pd.Series([1, 2]) + # df.join([other]) # this is successful + # df.join(other) # this raises 'ValueError: Other Series must have a name' + + # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API + # to join multiple DataFrame/Series. So always follow the behavior + # documented as Case #2 above. + + joined = self + for frame in other: + if isinstance(frame, DataFrame): + overlapping_cols = set(joined.columns).intersection(set(frame.columns)) + if len(overlapping_cols) > 0: + # Native pandas raises: 'Indexes have overlapping values' + # We differ slightly from native pandas message to make it more + # useful to users. + raise ValueError( + f"Join dataframes have overlapping column labels: {overlapping_cols}" + ) + joined = pd.merge( + joined, + frame, + how=how, + left_index=True, + right_index=True, + validate=validate, + sort=sort, + suffixes=(None, None), + ) + return joined + + +# Snowpark pandas does extra error checking. 
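+# Illustrative note (not part of the original change): for example, df.mask(cond, some_series)
+# without an explicit `axis` raises a ValueError here (mask requires an axis when `other` is a
+# Series) instead of deferring to the base implementation.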
+@register_dataframe_accessor("mask") +def mask( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(other, Series) and axis is None: + raise ValueError( + "df.mask requires an axis parameter (0 or 1) when given a Series" + ) + + return super(DataFrame, self).mask( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + +# Snowpark pandas has a fix for a pandas behavior change. It is available in Modin 0.30.1 (SNOW-1552497). +@register_dataframe_accessor("melt") +def melt( + self, + id_vars=None, + value_vars=None, + var_name=None, + value_name="value", + col_level=None, + ignore_index=True, +): # noqa: PR01, RT01, D200 + """ + Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if id_vars is None: + id_vars = [] + if not is_list_like(id_vars): + id_vars = [id_vars] + if value_vars is None: + # Behavior of Index.difference changed in 2.2.x + # https://github.com/pandas-dev/pandas/pull/55113 + # This change needs upstream to Modin: + # https://github.com/modin-project/modin/issues/7206 + value_vars = self.columns.drop(id_vars) + if var_name is None: + columns_name = self._query_compiler.get_index_name(axis=1) + var_name = columns_name if columns_name is not None else "variable" + return self.__constructor__( + query_compiler=self._query_compiler.melt( + id_vars=id_vars, + value_vars=value_vars, + var_name=var_name, + value_name=value_name, + col_level=col_level, + ignore_index=ignore_index, + ) + ) + + +# Snowpark pandas does more thorough error checking. +@register_dataframe_accessor("merge") +def merge( + self, + right: DataFrame | Series, + how: str = "inner", + on: IndexLabel | None = None, + left_on: Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike] | None = None, + right_on: Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike] | None = None, + left_index: bool = False, + right_index: bool = False, + sort: bool = False, + suffixes: Suffixes = ("_x", "_y"), + copy: bool = True, + indicator: bool = False, + validate: str | None = None, +) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # Raise error if native pandas objects are passed. + raise_if_native_pandas_objects(right) + + if isinstance(right, Series) and right.name is None: + raise ValueError("Cannot merge a Series without a name") + if not isinstance(right, (Series, DataFrame)): + raise TypeError( + f"Can only merge Series or DataFrame objects, a {type(right)} was passed" + ) + + if isinstance(right, Series): + right_column_nlevels = len(right.name) if isinstance(right.name, tuple) else 1 + else: + right_column_nlevels = right.columns.nlevels + if self.columns.nlevels != right_column_nlevels: + # This is deprecated in native pandas. We raise explicit error for this. + raise ValueError( + "Can not merge objects with different column levels." + + f" ({self.columns.nlevels} levels on the left," + + f" {right_column_nlevels} on the right)" + ) + + # Merge empty native pandas dataframes for error checking. Otherwise, it will + # require a lot of logic to be written. This takes care of raising errors for + # following scenarios: + # 1. Only 'left_index' is set to True. + # 2. 
Only 'right_index is set to True. + # 3. Only 'left_on' is provided. + # 4. Only 'right_on' is provided. + # 5. 'on' and 'left_on' both are provided + # 6. 'on' and 'right_on' both are provided + # 7. 'on' and 'left_index' both are provided + # 8. 'on' and 'right_index' both are provided + # 9. 'left_on' and 'left_index' both are provided + # 10. 'right_on' and 'right_index' both are provided + # 11. Length mismatch between 'left_on' and 'right_on' + # 12. 'left_index' is not a bool + # 13. 'right_index' is not a bool + # 14. 'on' is not None and how='cross' + # 15. 'left_on' is not None and how='cross' + # 16. 'right_on' is not None and how='cross' + # 17. 'left_index' is True and how='cross' + # 18. 'right_index' is True and how='cross' + # 19. Unknown label in 'on', 'left_on' or 'right_on' + # 20. Provided 'suffixes' is not sufficient to resolve conflicts. + # 21. Merging on column with duplicate labels. + # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'} + # 23. conflict with existing labels for array-like join key + # 24. 'indicator' argument is not bool or str + # 25. indicator column label conflicts with existing data labels + create_empty_native_pandas_frame(self).merge( + create_empty_native_pandas_frame(right), + on=on, + how=how, + left_on=replace_external_data_keys_with_empty_pandas_series(left_on), + right_on=replace_external_data_keys_with_empty_pandas_series(right_on), + left_index=left_index, + right_index=right_index, + suffixes=suffixes, + indicator=indicator, + ) + + return self.__constructor__( + query_compiler=self._query_compiler.merge( + right._query_compiler, + how=how, + on=on, + left_on=replace_external_data_keys_with_query_compiler(self, left_on), + right_on=replace_external_data_keys_with_query_compiler(right, right_on), + left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + copy=copy, + indicator=indicator, + validate=validate, + ) + ) @_inherit_docstrings(native_pd.DataFrame.memory_usage, apilink="pandas.DataFrame") @@ -62,6 +1485,125 @@ def memory_usage(self, index: bool = True, deep: bool = False) -> Any: return native_pd.Series([0] * len(columns), index=columns) +# Snowpark pandas handles `inplace` differently. +@register_dataframe_accessor("replace") +def replace( + self, + to_replace=None, + value=no_default, + inplace: bool = False, + limit=None, + regex: bool = False, + method: str | NoDefault = no_default, +): + """ + Replace values given in `to_replace` with `value`. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + new_query_compiler = self._query_compiler.replace( + to_replace=to_replace, + value=value, + limit=limit, + regex=regex, + method=method, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + +# Snowpark pandas interacts with the inplace flag differently. +@register_dataframe_accessor("rename") +def rename( + self, + mapper: Renamer | None = None, + *, + index: Renamer | None = None, + columns: Renamer | None = None, + axis: Axis | None = None, + copy: bool | None = None, + inplace: bool = False, + level: Level | None = None, + errors: IgnoreRaise = "ignore", +) -> DataFrame | None: + """ + Alter axes labels. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + if mapper is None and index is None and columns is None: + raise TypeError("must pass an index to rename") + + if index is not None or columns is not None: + if axis is not None: + raise TypeError( + "Cannot specify both 'axis' and any of 'index' or 'columns'" + ) + elif mapper is not None: + raise TypeError( + "Cannot specify both 'mapper' and any of 'index' or 'columns'" + ) + else: + # use the mapper argument + if axis and self._get_axis_number(axis) == 1: + columns = mapper + else: + index = mapper + + if copy is not None: + WarningMessage.ignored_argument( + operation="dataframe.rename", + argument="copy", + message="copy parameter has been ignored with Snowflake execution engine", + ) + + if isinstance(index, dict): + index = Series(index) + + new_qc = self._query_compiler.rename( + index_renamer=index, columns_renamer=columns, level=level, errors=errors + ) + return self._create_or_update_from_compiler( + new_query_compiler=new_qc, inplace=inplace + ) + + +# Upstream modin converts aggfunc to a cython function if it's a string. +@register_dataframe_accessor("pivot_table") +def pivot_table( + self, + values=None, + index=None, + columns=None, + aggfunc="mean", + fill_value=None, + margins=False, + dropna=True, + margins_name="All", + observed=False, + sort=True, +): + """ + Create a spreadsheet-style pivot table as a ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + result = self.__constructor__( + query_compiler=self._query_compiler.pivot_table( + index=index, + values=values, + columns=columns, + aggfunc=aggfunc, + fill_value=fill_value, + margins=margins, + dropna=dropna, + margins_name=margins_name, + observed=observed, + sort=sort, + ) + ) + return result + + +# Snowpark pandas produces a different warning for materialization. @register_dataframe_accessor("plot") @property def plot( @@ -108,11 +1650,227 @@ def plot( return self._to_pandas().plot +# Upstream Modin defaults when other is a Series. +@register_dataframe_accessor("pow") +def pow( + self, other, axis="columns", level=None, fill_value=None +): # noqa: PR01, RT01, D200 + """ + Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "pow", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + +@register_dataframe_accessor("rpow") +def rpow( + self, other, axis="columns", level=None, fill_value=None +): # noqa: PR01, RT01, D200 + """ + Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "rpow", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + +# Snowpark pandas does extra argument validation, and uses iloc instead of drop at the end. +@register_dataframe_accessor("select_dtypes") +def select_dtypes( + self, + include: ListLike | str | type | None = None, + exclude: ListLike | str | type | None = None, +) -> DataFrame: + """ + Return a subset of the ``DataFrame``'s columns based on the column dtypes. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # This line defers argument validation to pandas, which will raise errors on our behalf in cases + # like if `include` and `exclude` are None, the same type is specified in both lists, or a string + # dtype (as opposed to object) is specified. + native_pd.DataFrame().select_dtypes(include, exclude) + + if include and not is_list_like(include): + include = [include] + elif include is None: + include = [] + if exclude and not is_list_like(exclude): + exclude = [exclude] + elif exclude is None: + exclude = [] + + sel = tuple(map(set, (include, exclude))) + + # The width of the np.int_/float_ alias differs between Windows and other platforms, so + # we need to include a workaround. + # https://github.com/numpy/numpy/issues/9464 + # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036 + def check_sized_number_infer_dtypes(dtype): + if (isinstance(dtype, str) and dtype == "int") or (dtype is int): + return [np.int32, np.int64] + elif dtype == "float" or dtype is float: + return [np.float64, np.float32] + else: + return [infer_dtype_from_object(dtype)] + + include, exclude = map( + lambda x: set( + itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x)) + ), + sel, + ) + # We need to index on column position rather than label in case of duplicates + include_these = native_pd.Series(not bool(include), index=range(len(self.columns))) + exclude_these = native_pd.Series(not bool(exclude), index=range(len(self.columns))) + + def is_dtype_instance_mapper(dtype): + return functools.partial(issubclass, dtype.type) + + for i, dtype in enumerate(self.dtypes): + if include: + include_these[i] = any(map(is_dtype_instance_mapper(dtype), include)) + if exclude: + exclude_these[i] = not any(map(is_dtype_instance_mapper(dtype), exclude)) + + dtype_indexer = include_these & exclude_these + indicate = [i for i, should_keep in dtype_indexer.items() if should_keep] + # We need to use iloc instead of drop in case of duplicate column names + return self.iloc[:, indicate] + + +# Snowpark pandas does extra validation on the `axis` argument. +@register_dataframe_accessor("set_axis") +def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, # ignored +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not is_scalar(axis): + raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") + return super(DataFrame, self).set_axis( + labels=labels, + # 'columns', 'rows, 'index, 0, and 1 are the only valid axis values for df. + axis=native_pd.DataFrame._get_axis_name(axis), + copy=copy, + ) + + +# Snowpark pandas needs extra logic for the lazy index class. +@register_dataframe_accessor("set_index") +def set_index( + self, + keys: IndexLabel + | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable], + drop: bool = True, + append: bool = False, + inplace: bool = False, + verify_integrity: bool = False, +) -> None | DataFrame: + """ + Set the ``DataFrame`` index using existing columns. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + if not isinstance(keys, list): + keys = [keys] + + # make sure key is either hashable, index, or series + label_or_series = [] + + missing = [] + columns = self.columns.tolist() + for key in keys: + raise_if_native_pandas_objects(key) + if isinstance(key, pd.Series): + label_or_series.append(key._query_compiler) + elif isinstance(key, (np.ndarray, list, Iterator)): + label_or_series.append(pd.Series(key)._query_compiler) + elif isinstance(key, (pd.Index, native_pd.MultiIndex)): + label_or_series += [s._query_compiler for s in self._to_series_list(key)] + else: + if not is_hashable(key): + raise TypeError( + f'The parameter "keys" may be a column key, one-dimensional array, or a list ' + f"containing only valid column keys and one-dimensional arrays. Received column " + f"of type {type(key)}" + ) + label_or_series.append(key) + found = key in columns + if columns.count(key) > 1: + raise ValueError(f"The column label '{key}' is not unique") + elif not found: + missing.append(key) + + if missing: + raise KeyError(f"None of {missing} are in the columns") + + new_query_compiler = self._query_compiler.set_index( + label_or_series, drop=drop, append=append + ) + + # TODO: SNOW-782633 improve this code once duplicate is supported + # this needs to pull all index which is inefficient + if verify_integrity and not new_query_compiler.index.is_unique: + duplicates = new_query_compiler.index[ + new_query_compiler.index.to_pandas().duplicated() + ].unique() + raise ValueError(f"Index has duplicate keys: {duplicates}") + + return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace) + + +# Upstream Modin uses `len(self.index)` instead of `len(self)`, which gives an extra query. +@register_dataframe_accessor("shape") +@property +def shape(self) -> tuple[int, int]: + """ + Return a tuple representing the dimensionality of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return len(self), len(self.columns) + + +# Snowpark pands has rewrites to minimize queries from length checks. +@register_dataframe_accessor("squeeze") +def squeeze(self, axis: Axis | None = None): + """ + Squeeze 1 dimensional axis objects into scalars. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) if axis is not None else None + len_columns = self._query_compiler.get_axis_len(1) + if axis == 1 and len_columns == 1: + return Series(query_compiler=self._query_compiler) + if axis in [0, None]: + # get_axis_len(0) results in a sql query to count number of rows in current + # dataframe. We should only compute len_index if axis is 0 or None. + len_index = len(self) + if axis is None and (len_columns == 1 or len_index == 1): + return Series(query_compiler=self._query_compiler).squeeze() + if axis == 0 and len_index == 1: + return Series(query_compiler=self.T._query_compiler) + return self.copy() + + # Upstream modin defines sum differently for series/DF, but we use the same implementation for both. @register_dataframe_accessor("sum") def sum( self, - axis: Union[Axis, None] = None, + axis: Axis | None = None, skipna: bool = True, numeric_only: bool = False, min_count: int = 0, @@ -130,6 +1888,70 @@ def sum( ) +# Snowpark pandas raises a warning where modin defaults to pandas. 
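+# Illustrative note (not part of the original change): for example, df.stack(future_stack=True)
+# still runs on Snowflake but emits an "ignored argument" warning for `future_stack` rather
+# than falling back to native pandas.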
+@register_dataframe_accessor("stack") +def stack( + self, + level: int | str | list = -1, + dropna: bool | NoDefault = no_default, + sort: bool | NoDefault = no_default, + future_stack: bool = False, # ignored +): + """ + Stack the prescribed level(s) from columns to index. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if future_stack is not False: + WarningMessage.ignored_argument( # pragma: no cover + operation="DataFrame.stack", + argument="future_stack", + message="future_stack parameter has been ignored with Snowflake execution engine", + ) + if dropna is NoDefault: + dropna = True # pragma: no cover + if sort is NoDefault: + sort = True # pragma: no cover + + # This ensures that non-pandas MultiIndex objects are caught. + is_multiindex = len(self.columns.names) > 1 + if not is_multiindex or ( + is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels + ): + return self._reduce_dimension( + query_compiler=self._query_compiler.stack(level, dropna, sort) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.stack(level, dropna, sort) + ) + + +# Upstream modin doesn't pass `copy`, so we can't raise a warning for it. +# No need to override the `T` property since that can't take any extra arguments. +@register_dataframe_accessor("transpose") +def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200 + """ + Transpose index and columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if copy: + WarningMessage.ignored_argument( + operation="transpose", + argument="copy", + message="Transpose ignore copy argument in Snowpark pandas API", + ) + + if args: + WarningMessage.ignored_argument( + operation="transpose", + argument="args", + message="Transpose ignores args in Snowpark pandas API", + ) + + return self.__constructor__(query_compiler=self._query_compiler.transpose()) + + +# Upstream modin implements transform in base.py, but we don't yet support Series.transform. @register_dataframe_accessor("transform") def transform( self, func: PythonFuncType, axis: Axis = 0, *args: Any, **kwargs: Any @@ -151,3 +1973,380 @@ def transform( raise ValueError("Function did not transform") return self.apply(func, axis, False, args=args, **kwargs) + + +# Upstream modin defaults to pandas for some arguments. +@register_dataframe_accessor("unstack") +def unstack( + self, + level: int | str | list = -1, + fill_value: int | str | dict = None, + sort: bool = True, +): + """ + Pivot a level of the (necessarily hierarchical) index labels. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # This ensures that non-pandas MultiIndex objects are caught. + nlevels = self._query_compiler.nlevels() + is_multiindex = nlevels > 1 + + if not is_multiindex or ( + is_multiindex and is_list_like(level) and len(level) == nlevels + ): + return self._reduce_dimension( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=False + ) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=False + ) + ) + + +# Upstream modin does different validation and sorting. 
+@register_dataframe_accessor("value_counts") +def value_counts( + self, + subset: Sequence[Hashable] | None = None, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + dropna: bool = True, +): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return Series( + query_compiler=self._query_compiler.value_counts( + subset=subset, + normalize=normalize, + sort=sort, + ascending=ascending, + dropna=dropna, + ), + name="proportion" if normalize else "count", + ) + + +@register_dataframe_accessor("where") +def where( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + """ + Replace values where the condition is False. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(other, Series) and axis is None: + raise ValueError( + "df.where requires an axis parameter (0 or 1) when given a Series" + ) + + return super(DataFrame, self).where( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + +# Snowpark pandas has a custom iterator. +@register_dataframe_accessor("iterrows") +def iterrows(self) -> Iterator[tuple[Hashable, Series]]: + """ + Iterate over ``DataFrame`` rows as (index, ``Series``) pairs. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + def iterrow_builder(s): + """Return tuple of the given `s` parameter name and the parameter themselves.""" + return s.name, s + + # Raise warning message since iterrows is very inefficient. + WarningMessage.single_warning( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows") + ) + + partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder) + yield from partition_iterator + + +# Snowpark pandas has a custom iterator. +@register_dataframe_accessor("itertuples") +def itertuples( + self, index: bool = True, name: str | None = "Pandas" +) -> Iterable[tuple[Any, ...]]: + """ + Iterate over ``DataFrame`` rows as ``namedtuple``-s. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + + def itertuples_builder(s): + """Return the next namedtuple.""" + # s is the Series of values in the current row. + fields = [] # column names + data = [] # values under each column + + if index: + data.append(s.name) + fields.append("Index") + + # Fill column names and values. + fields.extend(list(self.columns)) + data.extend(s) + + if name is not None: + # Creating the namedtuple. + itertuple = collections.namedtuple(name, fields, rename=True) + return itertuple._make(data) + + # When the name is None, return a regular tuple. + return tuple(data) + + # Raise warning message since itertuples is very inefficient. + WarningMessage.single_warning( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples") + ) + return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True) + + +# Snowpark pandas truncates the repr output. +@register_dataframe_accessor("__repr__") +def __repr__(self): + """ + Return a string representation for a particular ``DataFrame``. 
+ + Returns + ------- + str + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + num_rows = native_pd.get_option("display.max_rows") or len(self) + # see _repr_html_ for comment, allow here also all column behavior + num_cols = native_pd.get_option("display.max_columns") or len(self.columns) + + ( + row_count, + col_count, + repr_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x") + result = repr(repr_df) + + # if truncated, add shape information + if is_repr_truncated(row_count, col_count, num_rows, num_cols): + # The split here is so that we don't repr pandas row lengths. + return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format( + row_count, col_count + ) + else: + return result + + +# Snowpark pandas uses a different default `num_rows` value. +@register_dataframe_accessor("_repr_html_") +def _repr_html_(self): # pragma: no cover + """ + Return a html representation for a particular ``DataFrame``. + + Returns + ------- + str + + Notes + ----- + Supports pandas `display.max_rows` and `display.max_columns` options. + """ + num_rows = native_pd.get_option("display.max_rows") or 60 + # Modin uses here 20 as default, but this does not coincide well with pandas option. Therefore allow + # here value=0 which means display all columns. + num_cols = native_pd.get_option("display.max_columns") + + ( + row_count, + col_count, + repr_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols) + result = repr_df._repr_html_() + + if is_repr_truncated(row_count, col_count, num_rows, num_cols): + # We split so that we insert our correct dataframe dimensions. + return ( + result.split("

")[0] + + f"

{row_count} rows × {col_count} columns

\n" + ) + else: + return result + + +# Upstream modin just uses `to_datetime` rather than `dataframe_to_datetime` on the query compiler. +@register_dataframe_accessor("_to_datetime") +def _to_datetime(self, **kwargs): + """ + Convert `self` to datetime. + + Parameters + ---------- + **kwargs : dict + Optional arguments to use during query compiler's + `to_datetime` invocation. + + Returns + ------- + Series of datetime64 dtype + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._reduce_dimension( + query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) + ) + + +# Snowpark pandas has the extra `statement_params` argument. +@register_dataframe_accessor("_to_pandas") +def _to_pandas( + self, + *, + statement_params: dict[str, str] | None = None, + **kwargs: Any, +) -> native_pd.DataFrame: + """ + Convert Snowpark pandas DataFrame to pandas DataFrame + + Args: + statement_params: Dictionary of statement level parameters to be set while executing this action. + + Returns: + pandas DataFrame + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.to_pandas(statement_params=statement_params, **kwargs) + + +# Snowpark pandas does more validation and error checking than upstream Modin, and uses different +# helper methods for dispatch. +@register_dataframe_accessor("__setitem__") +def __setitem__(self, key: Any, value: Any): + """ + Set attribute `value` identified by `key`. + + Args: + key: Key to set + value: Value to set + + Note: + In the case where value is any list like or array, pandas checks the array length against the number of rows + of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw + a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if + the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use + enlargement filling with the last value in the array. + + Returns: + None + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + key = apply_if_callable(key, self) + if isinstance(key, DataFrame) or ( + isinstance(key, np.ndarray) and len(key.shape) == 2 + ): + # This case uses mask's codepath to perform the set, but + # we need to duplicate the code here since we are passing + # an additional kwarg `cond_fillna_with_true` to the QC here. + # We need this additional kwarg, since if df.shape + # and key.shape do not align (i.e. df has more rows), + # mask's codepath would mask the additional rows in df + # while for setitem, we need to keep the original values. + if not isinstance(key, DataFrame): + if key.dtype != bool: + raise TypeError( + "Must pass DataFrame or 2-d ndarray with boolean values only" + ) + key = DataFrame(key) + key._query_compiler._shape_hint = "array" + + if value is not None: + value = apply_if_callable(value, self) + + if isinstance(value, np.ndarray): + value = DataFrame(value) + value._query_compiler._shape_hint = "array" + elif isinstance(value, pd.Series): + # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this + # error instead, since it is more descriptive. + raise ValueError( + "setitem with a 2D key does not support Series values." 
+ ) + + if isinstance(value, BasePandasDataset): + value = value._query_compiler + + query_compiler = self._query_compiler.mask( + cond=key._query_compiler, + other=value, + axis=None, + level=None, + cond_fillna_with_true=True, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace=True) + + # Error Checking: + if (isinstance(key, pd.Series) or is_list_like(key)) and (isinstance(value, range)): + raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) + elif isinstance(value, slice): + # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. + raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) + + # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column + # key. + index, columns = slice(None), key + index_is_bool_indexer = False + if isinstance(key, slice): + if is_integer(key.start) and is_integer(key.stop): + # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as + # df.iloc[1:2, :] = val + self.iloc[key] = value + return + index, columns = key, slice(None) + elif isinstance(key, pd.Series): + if is_bool_dtype(key.dtype): + index, columns = key, slice(None) + index_is_bool_indexer = True + elif is_bool_indexer(key): + index, columns = pd.Series(key), slice(None) + index_is_bool_indexer = True + + # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case + # we have to explicitly set matching_item_columns_by_label to False for setitem. + index = index._query_compiler if isinstance(index, BasePandasDataset) else index + columns = ( + columns._query_compiler if isinstance(columns, BasePandasDataset) else columns + ) + from snowflake.snowpark.modin.pandas.indexing import is_2d_array + + matching_item_rows_by_label = not is_2d_array(value) + if is_2d_array(value): + value = DataFrame(value) + item = value._query_compiler if isinstance(value, BasePandasDataset) else value + new_qc = self._query_compiler.set_2d_labels( + index, + columns, + item, + # setitem always matches item by position + matching_item_columns_by_label=False, + matching_item_rows_by_label=matching_item_rows_by_label, + index_is_bool_indexer=index_is_bool_indexer, + # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling + # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the + # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have + # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns + # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", + # "X", "X". 
+ deduplicate_columns=True, + ) + return self._update_inplace(new_query_compiler=new_qc) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index df136af1a34..38edb9f7bee 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -960,7 +960,6 @@ def snap(self, freq: Frequency = "S") -> DatetimeIndex: DatetimeIndex(['2023-01-01', '2023-01-01', '2023-02-01', '2023-02-01'], dtype='datetime64[ns]', freq=None) """ - @datetime_index_not_implemented() def tz_convert(self, tz) -> DatetimeIndex: """ Convert tz-aware Datetime Array/Index from one time zone to another. @@ -1025,8 +1024,14 @@ def tz_convert(self, tz) -> DatetimeIndex: '2014-08-01 09:00:00'], dtype='datetime64[ns]', freq='h') """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. + return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_convert( + tz, + include_index=True, + ) + ) - @datetime_index_not_implemented() def tz_localize( self, tz, @@ -1104,21 +1109,29 @@ def tz_localize( Localize DatetimeIndex in US/Eastern time zone: - >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP - >>> tz_aware # doctest: +SKIP - DatetimeIndex(['2018-03-01 09:00:00-05:00', - '2018-03-02 09:00:00-05:00', + >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') + >>> tz_aware + DatetimeIndex(['2018-03-01 09:00:00-05:00', '2018-03-02 09:00:00-05:00', '2018-03-03 09:00:00-05:00'], - dtype='datetime64[ns, US/Eastern]', freq=None) + dtype='datetime64[ns, UTC-05:00]', freq=None) With the ``tz=None``, we can remove the time zone information while keeping the local time (not converted to UTC): - >>> tz_aware.tz_localize(None) # doctest: +SKIP + >>> tz_aware.tz_localize(None) DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', '2018-03-03 09:00:00'], dtype='datetime64[ns]', freq=None) """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. + return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_localize( + tz, + ambiguous, + nonexistent, + include_index=True, + ) + ) def round( self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 12710224de7..b25bb481dc0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -30,6 +30,7 @@ import modin import numpy as np import pandas as native_pd +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas import get_option from pandas._libs import lib @@ -49,7 +50,6 @@ ) from pandas.core.dtypes.inference import is_hashable -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin._internal.timestamp_utils import DateTimeOrigin @@ -71,6 +71,35 @@ } +class IndexParent: + def __init__(self, parent: DataFrame | Series) -> None: + """ + Initialize the IndexParent object. + + IndexParent is used to keep track of the parent object that the Index is a part of. + It tracks the parent object and the parent object's query compiler at the time of creation. 
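+        The recorded query compiler is later compared against the parent's live query
+        compiler (see ``check_and_update_parent_qc_index_names``) so that index-name
+        changes are only pushed back to a parent that has not since been replaced by an
+        inplace update.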
+ + Parameters + ---------- + parent : DataFrame or Series + The parent object that the Index is a part of. + """ + assert isinstance(parent, (DataFrame, Series)) + self._parent = parent + self._parent_qc = parent._query_compiler + + def check_and_update_parent_qc_index_names(self, names: list) -> None: + """ + Update the Index and its parent's index names if the query compiler associated with the parent is + different from the original query compiler recorded, i.e., an inplace update has been applied to the parent. + """ + if self._parent._query_compiler is self._parent_qc: + new_query_compiler = self._parent_qc.set_index_names(names) + self._parent._update_inplace(new_query_compiler=new_query_compiler) + # Update the query compiler after naming operation. + self._parent_qc = new_query_compiler + + class Index(metaclass=TelemetryMeta): # Equivalent index type in native pandas @@ -135,7 +164,7 @@ def __new__( index = object.__new__(cls) # Initialize the Index index._query_compiler = query_compiler - # `_parent` keeps track of any Series or DataFrame that this Index is a part of. + # `_parent` keeps track of the parent object that this Index is a part of. index._parent = None return index @@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any: ErrorMessage.not_implemented(f"Index.{key} is not yet implemented") raise err + def _set_parent(self, parent: Series | DataFrame) -> None: + """ + Set the parent object and its query compiler. + + Parameters + ---------- + parent : Series or DataFrame + The parent object that the Index is a part of. + """ + self._parent = IndexParent(parent) + def _binary_ops(self, method: str, other: Any) -> Index: if isinstance(other, Index): other = other.to_series().reset_index(drop=True) @@ -408,12 +448,6 @@ def __constructor__(self): """ return type(self) - def _set_parent(self, parent: Series | DataFrame): - """ - Set the parent object of the current Index to a given Series or DataFrame. - """ - self._parent = parent - @property def values(self) -> ArrayLike: """ @@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None: if not is_hashable(value): raise TypeError(f"{type(self).__name__}.name must be a hashable type") self._query_compiler = self._query_compiler.set_index_names([value]) + # Update the name of the parent's index only if an inplace update is performed on + # the parent object, i.e., the parent's current query compiler matches the originally + # recorded query compiler. if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names([value]) - ) + self._parent.check_and_update_parent_qc_index_names([value]) def _get_names(self) -> list[Hashable]: """ @@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None: if isinstance(values, Index): values = values.to_list() self._query_compiler = self._query_compiler.set_index_names(values) + # Update the name of the parent's index only if the parent's current query compiler + # matches the recorded query compiler. 
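+        # For example (illustrative): if the parent Series was modified inplace after this
+        # Index was created, its live query compiler no longer matches the recorded one, so
+        # setting names on the Index leaves the parent's index names untouched.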
if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names(values) - ) + self._parent.check_and_update_parent_qc_index_names(values) names = property(fset=_set_names, fget=_get_names) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py index c5f9e4f6cee..43f9603cfb4 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py @@ -9,10 +9,10 @@ import inspect from typing import Any, Iterable, Literal, Optional, Union +from modin.pandas import DataFrame, Series from pandas._typing import IndexLabel from snowflake.snowpark import DataFrame as SnowparkDataFrame -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_standalone_function_decorator, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py index dea98bbb0d3..6d6fb4cd0bd 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py @@ -15,6 +15,7 @@ ) import pandas as native_pd +from modin.pandas import DataFrame from pandas._libs.lib import NoDefault, no_default from pandas._typing import ( CSVEngine, @@ -26,7 +27,6 @@ ) import snowflake.snowpark.modin.pandas as pd -from snowflake.snowpark.modin.pandas import DataFrame from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_standalone_function_decorator, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py index f7bba4c743a..5b245bfdab4 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py @@ -181,7 +181,6 @@ def to_pandas( See Also: - :func:`to_pandas ` - - :func:`DataFrame.to_pandas ` Returns: pandas Series diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 5011defa685..625e5b8032a 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -9,22 +9,13 @@ from __future__ import annotations -from typing import ( - IO, - TYPE_CHECKING, - Any, - Callable, - Hashable, - Literal, - Mapping, - Sequence, -) +from typing import IO, Any, Callable, Hashable, Literal, Mapping, Sequence import modin.pandas as pd import numpy as np import numpy.typing as npt import pandas as native_pd -from modin.pandas import Series +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas._libs.lib import NoDefault, is_integer, no_default from pandas._typing import ( @@ -73,9 +64,6 @@ validate_int_kwarg, ) -if TYPE_CHECKING: - from modin.pandas import DataFrame - def register_series_not_implemented(): def decorator(base_method: Any): @@ -209,21 +197,6 @@ def hist( pass # pragma: no cover -@register_series_not_implemented() -def interpolate( - self, - method="linear", - axis=0, 
- limit=None, - inplace=False, - limit_direction: str | None = None, - limit_area=None, - downcast=None, - **kwargs, -): # noqa: PR01, RT01, D200 - pass # pragma: no cover - - @register_series_not_implemented() def item(self): # noqa: RT01, D200 pass # pragma: no cover @@ -419,6 +392,25 @@ def __init__( self.name = name +@register_series_accessor("_update_inplace") +def _update_inplace(self, new_query_compiler) -> None: + """ + Update the current Series in-place using `new_query_compiler`. + + Parameters + ---------- + new_query_compiler : BaseQueryCompiler + QueryCompiler to use to manage the data. + """ + super(Series, self)._update_inplace(new_query_compiler=new_query_compiler) + # Propagate changes back to parent so that column in dataframe had the same contents + if self._parent is not None: + if self._parent_axis == 1 and isinstance(self._parent, DataFrame): + self._parent[self.name] = self + else: + self._parent.loc[self.index] = self + + # Since Snowpark pandas leaves all data on the warehouse, memory_usage's report of local memory # usage isn't meaningful and is set to always return 0. @_inherit_docstrings(native_pd.Series.memory_usage, apilink="pandas.Series") @@ -1451,9 +1443,7 @@ def set_axis( ) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. +# Snowpark pandas does different validation. @register_series_accessor("rename") def rename( self, @@ -1503,9 +1493,36 @@ def rename( return self_cp -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. +# Modin defaults to pandas for some arguments for unstack +@register_series_accessor("unstack") +def unstack( + self, + level: int | str | list = -1, + fill_value: int | str | dict = None, + sort: bool = True, +): + """ + Unstack, also known as pivot, Series with MultiIndex to produce DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + from modin.pandas.dataframe import DataFrame + + # We can't unstack a Series object, if we don't have a MultiIndex. + if self._query_compiler.has_multiindex: + result = DataFrame( + query_compiler=self._query_compiler.unstack( + level, fill_value, sort, is_series_input=True + ) + ) + else: + raise ValueError( # pragma: no cover + f"index must be a MultiIndex to unstack, {type(self.index)} was passed" + ) + + return result + + +# Snowpark pandas does an extra check on `len(ascending)`. @register_series_accessor("sort_values") def sort_values( self, @@ -1521,7 +1538,7 @@ def sort_values( Sort by the values. """ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame + from modin.pandas.dataframe import DataFrame if is_list_like(ascending) and len(ascending) != 1: raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series") @@ -1550,38 +1567,6 @@ def sort_values( return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once -# our vendored copy of DataFrame is removed. 
-# Modin also defaults to pandas for some arguments for unstack -@register_series_accessor("unstack") -def unstack( - self, - level: int | str | list = -1, - fill_value: int | str | dict = None, - sort: bool = True, -): - """ - Unstack, also known as pivot, Series with MultiIndex to produce DataFrame. - """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - # We can't unstack a Series object, if we don't have a MultiIndex. - if self._query_compiler.has_multiindex: - result = DataFrame( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=True - ) - ) - else: - raise ValueError( # pragma: no cover - f"index must be a MultiIndex to unstack, {type(self.index)} was passed" - ) - - return result - - # Upstream Modin defaults at the frontend layer. @register_series_accessor("where") def where( @@ -1727,63 +1712,6 @@ def to_dict(self, into: type[dict] = dict) -> dict: return self._to_pandas().to_dict(into=into) -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored -# version of DataFrame, we must keep this override. -@register_series_accessor("_create_or_update_from_compiler") -def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): - """ - Return or update a Series with given `new_query_compiler`. - - Parameters - ---------- - new_query_compiler : PandasQueryCompiler - QueryCompiler to use to manage the data. - inplace : bool, default: False - Whether or not to perform update or creation inplace. - - Returns - ------- - Series, DataFrame or None - None if update was done, Series or DataFrame otherwise. - """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - assert ( - isinstance(new_query_compiler, type(self._query_compiler)) - or type(new_query_compiler) in self._query_compiler.__class__.__bases__ - ), f"Invalid Query Compiler object: {type(new_query_compiler)}" - if not inplace and new_query_compiler.is_series_like(): - return self.__constructor__(query_compiler=new_query_compiler) - elif not inplace: - # This can happen with things like `reset_index` where we can add columns. - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - return DataFrame(query_compiler=new_query_compiler) - else: - self._update_inplace(new_query_compiler=new_query_compiler) - - -# TODO: SNOW-1063346 -# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored -# version of DataFrame, we must keep this override. -@register_series_accessor("to_frame") -def to_frame(self, name: Hashable = no_default) -> DataFrame: # noqa: PR01, RT01, D200 - """ - Convert Series to {label -> value} dict or dict-like object. 
- """ - # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame - - if name is None: - name = no_default - - self_cp = self.copy() - if name is not no_default: - self_cp.name = name - - return DataFrame(self_cp) - - @register_series_accessor("to_numpy") def to_numpy( self, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 1dbb743aa32..1cd5e31c63f 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -28,16 +28,11 @@ import numpy as np import pandas as native_pd +from modin.pandas import DataFrame, Series from pandas._libs import lib from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable from pandas.core.dtypes.common import is_timedelta64_dtype -from snowflake.snowpark import functions as fn -from snowflake.snowpark.modin.pandas import DataFrame, Series -from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( - AggregateColumnOpParameters, - aggregate_with_ordered_dataframe, -) from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, ) @@ -45,7 +40,6 @@ from snowflake.snowpark.modin.plugin.utils.error_message import ( timedelta_index_not_implemented, ) -from snowflake.snowpark.types import LongType _CONSTRUCTOR_DEFAULTS = { "unit": lib.no_default, @@ -433,19 +427,25 @@ def mean( raise ValueError( f"axis should be 0 for TimedeltaIndex.mean, found '{axis}'" ) - # TODO SNOW-1620439: Reuse code from Series.mean. - frame = self._query_compiler._modin_frame - index_id = frame.index_column_snowflake_quoted_identifiers[0] - new_index_id = frame.ordered_dataframe.generate_snowflake_quoted_identifiers( - pandas_labels=["mean"] - )[0] - agg_column_op_params = AggregateColumnOpParameters( - index_id, LongType(), "mean", new_index_id, fn.mean, [] + pandas_dataframe_result = ( + # reset_index(drop=False) copies the index column of + # self._query_compiler into a new data column. Use `drop=False` + # so that we don't have to use SQL row_number() to generate a new + # index column. + self._query_compiler.reset_index(drop=False) + # Aggregate the data column. + .agg("mean", axis=0, args=(), kwargs={"skipna": skipna}) + # convert the query compiler to a pandas dataframe with + # dimensions 1x1 (note that the frame has a single row even + # if `self` is empty.) + .to_pandas() ) - mean_value = aggregate_with_ordered_dataframe( - frame.ordered_dataframe, [agg_column_op_params], {"skipna": skipna} - ).collect()[0][0] - return native_pd.Timedelta(np.nan if mean_value is None else int(mean_value)) + assert pandas_dataframe_result.shape == ( + 1, + 1, + ), "Internal error: aggregation result is not 1x1." + # Return the only element in the frame. 
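+        # The element is expected to be a native pandas Timedelta (or NaT when
+        # `self` is empty or all-NaT), matching native pandas
+        # TimedeltaIndex.mean.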
+ return pandas_dataframe_result.iloc[0, 0] @timedelta_index_not_implemented() def as_unit(self, unit: str) -> TimedeltaIndex: diff --git a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py index 785a492ca89..f3102115a32 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py +++ b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py @@ -42,3 +42,17 @@ SERIES_SETITEM_INCOMPATIBLE_INDEXER_WITH_SCALAR_ERROR_MESSAGE = ( "Scalar key incompatible with {} value" ) + +DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = ( + "Currently do not support Series or list-like keys with range-like values" +) + +DF_SETITEM_SLICE_AS_SCALAR_VALUE = ( + "Currently do not support assigning a slice value as if it's a scalar value" +) + +DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = ( + "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark " + "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which " + "can work on the entire DataFrame in one shot." +) diff --git a/src/snowflake/snowpark/modin/utils.py b/src/snowflake/snowpark/modin/utils.py index b1027f00e33..b3446ca0362 100644 --- a/src/snowflake/snowpark/modin/utils.py +++ b/src/snowflake/snowpark/modin/utils.py @@ -1171,7 +1171,7 @@ def validate_int_kwarg(value: int, arg_name: str, float_allowed: bool = False) - def doc_replace_dataframe_with_link(_obj: Any, doc: str) -> str: """ Helper function to be passed as the `modify_doc` parameter to `_inherit_docstrings`. This replaces - all unqualified instances of "DataFrame" with ":class:`~snowflake.snowpark.pandas.DataFrame`" to + all unqualified instances of "DataFrame" with ":class:`~modin.pandas.DataFrame`" to prevent it from linking automatically to snowflake.snowpark.DataFrame: see SNOW-1233342. To prevent it from overzealously replacing examples in doctests or already-qualified paths, it diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 607cd047f2b..e4177842032 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -224,6 +224,16 @@ _PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION = ( "PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION" ) +_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND = ( + "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND" +) +_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND = ( + "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND" +) +# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT +# in Snowflake. This is the limit where we start seeing compilation errors. +DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000 +DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000 WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None @@ -580,14 +590,22 @@ def __init__( _PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION, False ) ) + # The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT + # in Snowflake. This is the limit where we start seeing compilation errors. 
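+        # The bounds are stored as (lower_bound, upper_bound); each entry falls
+        # back to the module-level default above unless the corresponding
+        # PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_*_BOUND session
+        # parameter overrides it.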
+ self._large_query_breakdown_complexity_bounds: Tuple[int, int] = ( + self._conn._get_client_side_session_parameter( + _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + ), + self._conn._get_client_side_session_parameter( + _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + ), + ) self._custom_package_usage_config: Dict = {} self._conf = self.RuntimeConfig(self, options or {}) - self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) - if self._auto_clean_up_temp_table_enabled: - self._temp_table_auto_cleaner.start() - _logger.info("Snowpark Session information: %s", self._session_info) def __enter__(self): @@ -626,8 +644,8 @@ def close(self) -> None: raise SnowparkClientExceptionMessages.SERVER_FAILED_CLOSE_SESSION(str(ex)) finally: try: - self._conn.close() self._temp_table_auto_cleaner.stop() + self._conn.close() _logger.info("Closed session: %s", self._session_id) finally: _remove_session(self) @@ -661,10 +679,33 @@ def auto_clean_up_temp_table_enabled(self) -> bool: :meth:`DataFrame.cache_result` in the current session when the DataFrame is no longer referenced (i.e., gets garbage collected). The default value is ``False``. + Example:: + + >>> import gc + >>> + >>> def f(session: Session) -> str: + ... df = session.create_dataframe( + ... [[1, 2], [3, 4]], schema=["a", "b"] + ... ).cache_result() + ... return df.table_name + ... + >>> session.auto_clean_up_temp_table_enabled = True + >>> table_name = f(session) + >>> assert table_name + >>> gc.collect() # doctest: +SKIP + >>> + >>> # The temporary table created by cache_result will be dropped when the DataFrame is no longer referenced + >>> # outside the function + >>> session.sql(f"show tables like '{table_name}'").count() + 0 + + >>> session.auto_clean_up_temp_table_enabled = False + Note: - Even if this parameter is ``False``, Snowpark still records temporary tables when - their corresponding DataFrame are garbage collected. Therefore, if you turn it on in the middle of your session or after turning it off, - the target temporary tables will still be cleaned up accordingly. + Temporary tables will only be dropped if this parameter is enabled during garbage collection. + If a temporary table is no longer referenced when the parameter is on, it will be dropped during garbage collection. + However, if garbage collection occurs while the parameter is off, the table will not be removed. + Note that Python's garbage collection is triggered opportunistically, with no guaranteed timing. """ return self._auto_clean_up_temp_table_enabled @@ -672,6 +713,10 @@ def auto_clean_up_temp_table_enabled(self) -> bool: def large_query_breakdown_enabled(self) -> bool: return self._large_query_breakdown_enabled + @property + def large_query_breakdown_complexity_bounds(self) -> Tuple[int, int]: + return self._large_query_breakdown_complexity_bounds + @property def custom_package_usage_config(self) -> Dict: """Get or set configuration parameters related to usage of custom Python packages in Snowflake. 
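A minimal usage sketch for the complexity-bounds configuration introduced in this file, assuming an open `session`; the bound values here are illustrative, and the setter defined below accepts only a 2-tuple whose lower bound is strictly less than its upper bound (otherwise it raises ValueError):

    >>> session.large_query_breakdown_complexity_bounds = (300, 600)
    >>> session.large_query_breakdown_complexity_bounds
    (300, 600)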
@@ -758,11 +803,6 @@ def auto_clean_up_temp_table_enabled(self, value: bool) -> None: self._session_id, value ) self._auto_clean_up_temp_table_enabled = value - is_alive = self._temp_table_auto_cleaner.is_alive() - if value and not is_alive: - self._temp_table_auto_cleaner.start() - elif not value and is_alive: - self._temp_table_auto_cleaner.stop() else: raise ValueError( "value for auto_clean_up_temp_table_enabled must be True or False!" @@ -787,6 +827,24 @@ def large_query_breakdown_enabled(self, value: bool) -> None: "value for large_query_breakdown_enabled must be True or False!" ) + @large_query_breakdown_complexity_bounds.setter + def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: + """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" + + if len(value) != 2: + raise ValueError( + f"Expecting a tuple of two integers. Got a tuple of length {len(value)}" + ) + if value[0] >= value[1]: + raise ValueError( + f"Expecting a tuple of lower and upper bound with the lower bound less than the upper bound. Got (lower, upper) = ({value[0], value[1]})" + ) + self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds( + self._session_id, value[0], value[1] + ) + + self._large_query_breakdown_complexity_bounds = value + @custom_package_usage_config.setter @experimental_parameter(version="1.6.0") def custom_package_usage_config(self, config: Dict) -> None: @@ -1654,8 +1712,8 @@ def _upload_unsupported_packages( try: # Setup a temporary directory and target folder where pip install will take place. - self._tmpdir_handler = tempfile.TemporaryDirectory() - tmpdir = self._tmpdir_handler.name + tmpdir_handler = tempfile.TemporaryDirectory() + tmpdir = tmpdir_handler.name target = os.path.join(tmpdir, "unsupported_packages") if not os.path.exists(target): os.makedirs(target) @@ -1740,9 +1798,7 @@ def _upload_unsupported_packages( for requirement in supported_dependencies + new_dependencies ] ) - metadata_local_path = os.path.join( - self._tmpdir_handler.name, metadata_file - ) + metadata_local_path = os.path.join(tmpdir_handler.name, metadata_file) with open(metadata_local_path, "w") as file: for key, value in metadata.items(): file.write(f"{key},{value}\n") @@ -1778,9 +1834,8 @@ def _upload_unsupported_packages( f"-third-party-packages-from-anaconda-in-a-udf." 
) finally: - if self._tmpdir_handler: - self._tmpdir_handler.cleanup() - self._tmpdir_handler = None + if tmpdir_handler: + tmpdir_handler.cleanup() return supported_dependencies + new_dependencies diff --git a/tests/integ/modin/conftest.py b/tests/integ/modin/conftest.py index 2f24954e769..a7217b38a50 100644 --- a/tests/integ/modin/conftest.py +++ b/tests/integ/modin/conftest.py @@ -715,3 +715,30 @@ def numeric_test_data_4x4(): "C": [7, 10, 13, 16], "D": [8, 11, 14, 17], } + + +@pytest.fixture +def timedelta_native_df() -> pandas.DataFrame: + return pandas.DataFrame( + { + "A": [ + pd.Timedelta(days=1), + pd.Timedelta(days=2), + pd.Timedelta(days=3), + pd.Timedelta(days=4), + ], + "B": [ + pd.Timedelta(minutes=-1), + pd.Timedelta(minutes=0), + pd.Timedelta(minutes=5), + pd.Timedelta(minutes=6), + ], + "C": [ + None, + pd.Timedelta(nanoseconds=5), + pd.Timedelta(nanoseconds=0), + pd.Timedelta(nanoseconds=4), + ], + "D": pandas.to_timedelta([pd.NaT] * 4), + } + ) diff --git a/tests/integ/modin/frame/test_aggregate.py b/tests/integ/modin/frame/test_aggregate.py index b018682b6f8..ba68ae13734 100644 --- a/tests/integ/modin/frame/test_aggregate.py +++ b/tests/integ/modin/frame/test_aggregate.py @@ -187,6 +187,108 @@ def test_string_sum_with_nulls(): assert_series_equal(snow_result.to_pandas(), native_pd.Series(["ab"])) +class TestTimedelta: + """Test aggregating dataframes containing timedelta columns.""" + + @pytest.mark.parametrize( + "func, union_count", + [ + param( + lambda df: df.aggregate(["min"]), + 0, + id="aggregate_list_with_one_element", + ), + param(lambda df: df.aggregate(x=("A", "max")), 0, id="single_named_agg"), + # this works since all results are timedelta and we don't need to do any concats. + param( + lambda df: df.aggregate({"B": "mean", "A": "sum"}), + 0, + id="dict_producing_two_timedeltas", + ), + # this works since even though we need to do concats, all the results are non-timdelta. 
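+            # (The union_count of 1 below reflects stacking the two boolean
+            # results into a single frame.)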
+ param( + lambda df: df.aggregate(x=("B", "all"), y=("B", "any")), + 1, + id="named_agg_producing_two_bools", + ), + # note following aggregation requires transpose + param(lambda df: df.aggregate(max), 0, id="aggregate_max"), + param(lambda df: df.min(), 0, id="min"), + param(lambda df: df.max(), 0, id="max"), + param(lambda df: df.count(), 0, id="count"), + param(lambda df: df.sum(), 0, id="sum"), + param(lambda df: df.mean(), 0, id="mean"), + param(lambda df: df.median(), 0, id="median"), + param(lambda df: df.std(), 0, id="std"), + param(lambda df: df.quantile(), 0, id="single_quantile"), + param(lambda df: df.quantile([0.01, 0.99]), 1, id="two_quantiles"), + ], + ) + def test_supported_axis_0(self, func, union_count, timedelta_native_df): + with SqlCounter(query_count=1, union_count=union_count): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + func, + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1653126") + def test_axis_1(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), lambda df: df.sum(axis=1) + ) + + @sql_count_checker(query_count=0) + def test_var_invalid(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + lambda df: df.var(), + expect_exception=True, + expect_exception_type=TypeError, + assert_exception_equal=False, + expect_exception_match=re.escape( + "timedelta64 type does not support var operations" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", + ) + @pytest.mark.parametrize( + "operation", + [ + lambda df: df.aggregate({"A": ["count", "max"], "B": [max, "min"]}), + lambda df: df.aggregate({"B": ["count"], "A": "sum", "C": ["max", "min"]}), + lambda df: df.aggregate( + x=pd.NamedAgg("A", "max"), y=("B", "min"), c=("A", "count") + ), + lambda df: df.aggregate(["min", np.max]), + lambda df: df.aggregate(x=("A", "max"), y=("C", "min"), z=("A", "min")), + lambda df: df.aggregate(x=("A", "max"), y=pd.NamedAgg("A", "max")), + lambda df: df.aggregate( + {"B": ["idxmax"], "A": "sum", "C": ["max", "idxmin"]} + ), + ], + ) + def test_agg_requires_concat_with_timedelta(self, timedelta_native_df, operation): + eval_snowpark_pandas_result(*create_test_dfs(timedelta_native_df), operation) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires transposing a one-row frame with integer and timedelta.", + ) + def test_agg_produces_timedelta_and_non_timedelta_type(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + lambda df: df.aggregate({"B": "idxmax", "A": "sum"}), + ) + + @pytest.mark.parametrize( "func, expected_union_count", [ diff --git a/tests/integ/modin/frame/test_describe.py b/tests/integ/modin/frame/test_describe.py index a9668c5794f..4f1882d441d 100644 --- a/tests/integ/modin/frame/test_describe.py +++ b/tests/integ/modin/frame/test_describe.py @@ -358,3 +358,18 @@ def test_describe_object_file(resources_path): df = pd.read_csv(test_files.test_concat_file1_csv) native_df = df.to_pandas() eval_snowpark_pandas_result(df, native_df, lambda x: x.describe(include="O")) + + +@sql_count_checker(query_count=0) +@pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", +) 
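+# describe() assembles its output (count, mean, std, min, quantiles, max) by
+# concatenating several aggregation results, which is what hits the
+# unsupported Timedelta concat named in the xfail reason above.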
+def test_timedelta(timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs( + timedelta_native_df, + ), + lambda df: df.describe(), + ) diff --git a/tests/integ/modin/frame/test_idxmax_idxmin.py b/tests/integ/modin/frame/test_idxmax_idxmin.py index 72fe88968bc..87041060bd2 100644 --- a/tests/integ/modin/frame/test_idxmax_idxmin.py +++ b/tests/integ/modin/frame/test_idxmax_idxmin.py @@ -196,8 +196,18 @@ def test_idxmax_idxmin_with_dates(func, axis): @sql_count_checker(query_count=1) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) -@pytest.mark.parametrize("axis", [0, 1]) -@pytest.mark.xfail(reason="SNOW-1625380 TODO") +@pytest.mark.parametrize( + "axis", + [ + 0, + pytest.param( + 1, + marks=pytest.mark.xfail( + strict=True, raises=NotImplementedError, reason="SNOW-1653126" + ), + ), + ], +) def test_idxmax_idxmin_with_timedelta(func, axis): native_df = native_pd.DataFrame( data={ diff --git a/tests/integ/modin/frame/test_info.py b/tests/integ/modin/frame/test_info.py index 2a096e76fdc..fbbf8dfe041 100644 --- a/tests/integ/modin/frame/test_info.py +++ b/tests/integ/modin/frame/test_info.py @@ -13,9 +13,7 @@ def _assert_info_lines_equal(modin_info: list[str], pandas_info: list[str]): # class is different - assert ( - modin_info[0] == "" - ) + assert modin_info[0] == "" assert pandas_info[0] == "" # index is different diff --git a/tests/integ/modin/frame/test_loc.py b/tests/integ/modin/frame/test_loc.py index be51b8c9ae6..105bf475f3a 100644 --- a/tests/integ/modin/frame/test_loc.py +++ b/tests/integ/modin/frame/test_loc.py @@ -4072,3 +4072,22 @@ def test_df_loc_get_with_timedelta_and_none_key(): # Compare with an empty DataFrame, since native pandas raises a KeyError. expected_df = native_pd.DataFrame() assert_frame_equal(snow_df.loc[None], expected_df, check_column_type=False) + + +@sql_count_checker(query_count=0) +def test_df_loc_invalid_key(): + # Bug fix: SNOW-1320674 + native_df = native_pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + snow_df = pd.DataFrame(native_df) + + def op(df): + df["C"] = df["A"] / df["D"] + + eval_snowpark_pandas_result( + snow_df, + native_df, + op, + expect_exception=True, + expect_exception_type=KeyError, + expect_exception_match="D", + ) diff --git a/tests/integ/modin/frame/test_nunique.py b/tests/integ/modin/frame/test_nunique.py index d0cad8ec2ad..78098d34386 100644 --- a/tests/integ/modin/frame/test_nunique.py +++ b/tests/integ/modin/frame/test_nunique.py @@ -11,8 +11,13 @@ from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result -TEST_LABELS = np.array(["A", "B", "C", "D"]) -TEST_DATA = [[0, 1, 2, 3], [0, 0, 0, 0], [None, 0, None, 0], [None, None, None, None]] +TEST_LABELS = np.array(["A", "B", "C", "D", "E"]) +TEST_DATA = [ + [0, 1, 2, 3, pd.Timedelta(4)], + [0, 0, 0, 0, pd.Timedelta(0)], + [None, 0, None, 0, pd.Timedelta(0)], + [None, None, None, None, None], +] # which original dataframe (constructed from slicing) to test for TEST_SLICES = [ @@ -80,7 +85,7 @@ def test_dataframe_nunique_no_columns(native_df): [ pytest.param(None, id="default_columns"), pytest.param( - [["bar", "bar", "baz", "foo"], ["one", "two", "one", "two"]], + [["bar", "bar", "baz", "foo", "foo"], ["one", "two", "one", "two", "one"]], id="2D_columns", ), ], diff --git a/tests/integ/modin/frame/test_skew.py b/tests/integ/modin/frame/test_skew.py index 72fad6cebdc..94b7fd79c24 100644 --- a/tests/integ/modin/frame/test_skew.py +++ 
b/tests/integ/modin/frame/test_skew.py @@ -8,7 +8,11 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import assert_series_equal +from tests.integ.modin.utils import ( + assert_series_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) @sql_count_checker(query_count=1) @@ -62,16 +66,22 @@ def test_skew_basic(): }, "kwargs": {"numeric_only": True, "skipna": True}, }, + { + "frame": { + "A": [pd.Timedelta(1)], + }, + "kwargs": { + "numeric_only": True, + }, + }, ], ) @sql_count_checker(query_count=1) def test_skew(data): - native_df = native_pd.DataFrame(data["frame"]) - snow_df = pd.DataFrame(native_df) - assert_series_equal( - snow_df.skew(**data["kwargs"]), - native_df.skew(**data["kwargs"]), - rtol=1.0e-5, + eval_snowpark_pandas_result( + *create_test_dfs(data["frame"]), + lambda df: df.skew(**data["kwargs"]), + rtol=1.0e-5 ) @@ -103,6 +113,14 @@ def test_skew(data): }, "kwargs": {"level": 2}, }, + { + "frame": { + "A": [pd.Timedelta(1)], + }, + "kwargs": { + "numeric_only": False, + }, + }, ], ) @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/groupby/test_all_any.py b/tests/integ/modin/groupby/test_all_any.py index d5234dfbdb5..df8df44d47c 100644 --- a/tests/integ/modin/groupby/test_all_any.py +++ b/tests/integ/modin/groupby/test_all_any.py @@ -14,7 +14,11 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + assert_frame_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) @pytest.mark.parametrize( @@ -109,3 +113,27 @@ def test_all_any_chained(): lambda df: df.apply(lambda ser: ser.str.len()) ) ) + + +@sql_count_checker(query_count=1) +def test_timedelta_any_with_nulls(): + """ + Test this case separately because pandas behavior is different from Snowpark pandas behavior. 
+ + pandas bug that does not apply to Snowpark pandas: + https://github.com/pandas-dev/pandas/issues/59712 + """ + snow_df, native_df = create_test_dfs( + { + "key": ["a"], + "A": native_pd.Series([pd.NaT], dtype="timedelta64[ns]"), + }, + ) + assert_frame_equal( + native_df.groupby("key").any(), + native_pd.DataFrame({"A": [True]}, index=native_pd.Index(["a"], name="key")), + ) + assert_frame_equal( + snow_df.groupby("key").any(), + native_pd.DataFrame({"A": [False]}, index=native_pd.Index(["a"], name="key")), + ) diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index 09acd49bb21..cbf5b75d48c 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -1096,60 +1096,81 @@ def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df): ) -@pytest.mark.parametrize( - "agg_func", - [ - "count", - "sum", - "mean", - "median", - "std", - ], -) -@pytest.mark.parametrize("by", ["A", "B"]) -@sql_count_checker(query_count=1) -def test_timedelta(agg_func, by): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ) - snow_df = pd.DataFrame(native_df) - - eval_snowpark_pandas_result( - snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)() - ) - - -def test_timedelta_groupby_agg(): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - "C": [True, False, False, True], - } +class TestTimedelta: + @sql_count_checker(query_count=1) + @pytest.mark.parametrize( + "method", + [ + "count", + "mean", + "min", + "max", + "idxmax", + "idxmin", + "sum", + "median", + "std", + "nunique", + ], ) - snow_df = pd.DataFrame(native_df) - with SqlCounter(query_count=1): + @pytest.mark.parametrize("by", ["A", "B"]) + def test_aggregation_methods(self, method, by): eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: getattr(df.groupby(by), method)(), ) - with SqlCounter(query_count=1): - eval_snowpark_pandas_result( - snow_df, - native_df, + + @sql_count_checker(query_count=1) + @pytest.mark.parametrize( + "operation", + [ + lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), lambda df: df.groupby("B").agg({"A": ["sum", "median"], "C": "min"}), + lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}), + lambda df: df.groupby("B").agg(["mean", "std"]), + lambda df: df.groupby("B").agg({"A": ["count", np.sum]}), + lambda df: df.groupby("B").agg({"A": "sum"}), + ], + ) + def test_agg(self, operation): + eval_snowpark_pandas_result( + *create_test_dfs( + native_pd.DataFrame( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + "C": [True, False, False, True], + } + ) + ), + operation, ) - with SqlCounter(query_count=1): + + @sql_count_checker(query_count=1) + def test_groupby_timedelta_var(self): + """ + Test that we can group by a timedelta column and take var() of an integer column. + + Note that we can't take the groupby().var() of the timedelta column because + var() is not defined for timedelta, in pandas or in Snowpark pandas. 
+ """ eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}), + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: df.groupby("A").var(), ) diff --git a/tests/integ/modin/groupby/test_groupby_first_last.py b/tests/integ/modin/groupby/test_groupby_first_last.py index 5da35806dd1..5e04d5a6fc2 100644 --- a/tests/integ/modin/groupby/test_groupby_first_last.py +++ b/tests/integ/modin/groupby/test_groupby_first_last.py @@ -46,6 +46,17 @@ [np.nan], ] ), + "col11_timedelta": [ + pd.Timedelta("1 days"), + None, + pd.Timedelta("2 days"), + None, + None, + None, + None, + None, + None, + ], } diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py index a009e1089b0..0c9c056c2a7 100644 --- a/tests/integ/modin/groupby/test_groupby_negative.py +++ b/tests/integ/modin/groupby/test_groupby_negative.py @@ -18,6 +18,7 @@ MAP_DATA_AND_TYPE, MIXED_NUMERIC_STR_DATA_AND_TYPE, TIMESTAMP_DATA_AND_TYPE, + create_test_dfs, eval_snowpark_pandas_result, ) @@ -559,20 +560,12 @@ def test_groupby_agg_invalid_min_count( @sql_count_checker(query_count=0) -def test_groupby_var_no_support_for_timedelta(): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ) - snow_df = pd.DataFrame(native_df) - with pytest.raises( - NotImplementedError, - match=re.escape( - "SnowflakeQueryCompiler::groupby_agg is not yet implemented for Timedelta Type" +def test_timedelta_var_invalid(): + eval_snowpark_pandas_result( + *create_test_dfs( + [["key0", pd.Timedelta(1)]], ), - ): - snow_df.groupby("B").var() + lambda df: df.groupby(0).var(), + expect_exception=True, + expect_exception_type=TypeError, + ) diff --git a/tests/integ/modin/groupby/test_quantile.py b/tests/integ/modin/groupby/test_quantile.py index b14299fee63..940d366a7e2 100644 --- a/tests/integ/modin/groupby/test_quantile.py +++ b/tests/integ/modin/groupby/test_quantile.py @@ -64,6 +64,14 @@ # ), # All NA ([np.nan] * 5, [np.nan] * 5), + pytest.param( + pd.timedelta_range( + "1 days", + "5 days", + ), + pd.timedelta_range("1 second", "5 second"), + id="timedelta", + ), ], ) @pytest.mark.parametrize("q", [0, 0.5, 1]) diff --git a/tests/integ/modin/index/conftest.py b/tests/integ/modin/index/conftest.py index 84454fc4a27..26afd232c4f 100644 --- a/tests/integ/modin/index/conftest.py +++ b/tests/integ/modin/index/conftest.py @@ -79,4 +79,5 @@ tz="America/Los_Angeles", ), native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]), + native_pd.TimedeltaIndex(["4 days", None, "-1 days", "5 days"]), ] diff --git a/tests/integ/modin/index/test_all_any.py b/tests/integ/modin/index/test_all_any.py index 267e7929ea1..499be6f03dc 100644 --- a/tests/integ/modin/index/test_all_any.py +++ b/tests/integ/modin/index/test_all_any.py @@ -25,6 +25,9 @@ native_pd.Index(["a", "b", "c", "d"]), native_pd.Index([5, None, 7]), native_pd.Index([], dtype="object"), + native_pd.Index([pd.Timedelta(0), None]), + native_pd.Index([pd.Timedelta(0)]), + native_pd.Index([pd.Timedelta(0), pd.Timedelta(1)]), ] NATIVE_INDEX_EMPTY_DATA = [ diff --git a/tests/integ/modin/index/test_argmax_argmin.py b/tests/integ/modin/index/test_argmax_argmin.py index 6d446a0a66a..7d42f3b88c9 100644 --- a/tests/integ/modin/index/test_argmax_argmin.py +++ 
b/tests/integ/modin/index/test_argmax_argmin.py @@ -18,6 +18,18 @@ native_pd.Index([4, None, 1, 3, 4, 1]), native_pd.Index([4, None, 1, 3, 4, 1], name="some name"), native_pd.Index([1, 10, 4, 3, 4]), + pytest.param( + native_pd.Index( + [ + pd.Timedelta(1), + pd.Timedelta(10), + pd.Timedelta(4), + pd.Timedelta(3), + pd.Timedelta(4), + ] + ), + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["argmax", "argmin"]) diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py index 143e1d74080..98d1a041c3b 100644 --- a/tests/integ/modin/index/test_datetime_index_methods.py +++ b/tests/integ/modin/index/test_datetime_index_methods.py @@ -7,6 +7,7 @@ import numpy as np import pandas as native_pd import pytest +import pytz import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker @@ -17,6 +18,46 @@ eval_snowpark_pandas_result, ) +timezones = pytest.mark.parametrize( + "tz", + [ + None, + # Use a subset of pytz.common_timezones containing a few timezones in each + *[ + param_for_one_tz + for tz in [ + "Africa/Abidjan", + "Africa/Timbuktu", + "America/Adak", + "America/Yellowknife", + "Antarctica/Casey", + "Asia/Dhaka", + "Asia/Manila", + "Asia/Shanghai", + "Atlantic/Stanley", + "Australia/Sydney", + "Canada/Pacific", + "Europe/Chisinau", + "Europe/Luxembourg", + "Indian/Christmas", + "Pacific/Chatham", + "Pacific/Wake", + "US/Arizona", + "US/Central", + "US/Eastern", + "US/Hawaii", + "US/Mountain", + "US/Pacific", + "UTC", + ] + for param_for_one_tz in ( + pytz.timezone(tz), + tz, + ) + ], + ], +) + @sql_count_checker(query_count=0) def test_datetime_index_construction(): @@ -101,13 +142,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame({"A": [1]}, index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. s = pd.Series([1, 2], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) @@ -233,6 +274,76 @@ def test_normalize(): ) +@sql_count_checker(query_count=1, join_count=1) +@timezones +def test_tz_convert(tz): + native_index = native_pd.date_range( + start="2021-01-01", periods=5, freq="7h", tz="US/Eastern" + ) + native_index = native_index.append( + native_pd.DatetimeIndex([pd.NaT], tz="US/Eastern") + ) + snow_index = pd.DatetimeIndex(native_index) + + # Using eval_snowpark_pandas_result() was not possible because currently + # Snowpark pandas DatetimeIndex only mainains a timzeone-naive dtype + # even if the data contains a timezone. + assert snow_index.tz_convert(tz).equals( + pd.DatetimeIndex(native_index.tz_convert(tz)) + ) + + +@sql_count_checker(query_count=1, join_count=1) +@timezones +def test_tz_localize(tz): + native_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + snow_index = pd.DatetimeIndex(native_index) + + # Using eval_snowpark_pandas_result() was not possible because currently + # Snowpark pandas DatetimeIndex only mainains a timzeone-naive dtype + # even if the data contains a timezone. 
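+    # Instead, localize with native pandas and wrap the result back into a
+    # Snowpark pandas DatetimeIndex so that both sides of equals() use the
+    # same timezone-naive representation.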
+ assert snow_index.tz_localize(tz).equals( + pd.DatetimeIndex(native_index.tz_localize(tz)) + ) + + +@pytest.mark.parametrize( + "ambiguous, nonexistent", + [ + ("infer", "raise"), + ("NaT", "raise"), + (np.array([True, True, False]), "raise"), + ("raise", "shift_forward"), + ("raise", "shift_backward"), + ("raise", "NaT"), + ("raise", pd.Timedelta("1h")), + ("infer", "shift_forward"), + ], +) +@sql_count_checker(query_count=0) +def test_tz_localize_negative(ambiguous, nonexistent): + native_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + snow_index = pd.DatetimeIndex(native_index) + with pytest.raises(NotImplementedError): + snow_index.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent) + + @pytest.mark.parametrize( "datetime_index_value", [ diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py index 8d0434915ac..6b33eb89889 100644 --- a/tests/integ/modin/index/test_index_methods.py +++ b/tests/integ/modin/index/test_index_methods.py @@ -393,13 +393,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) diff --git a/tests/integ/modin/index/test_name.py b/tests/integ/modin/index/test_name.py index b916110f386..f915598c5f6 100644 --- a/tests/integ/modin/index/test_name.py +++ b/tests/integ/modin/index/test_name.py @@ -351,3 +351,69 @@ def test_index_names_with_lazy_index(): ), inplace=True, ) + + +@sql_count_checker(query_count=1) +def test_index_names_replace_behavior(): + """ + Check that the index name of a DataFrame cannot be updated after the DataFrame has been modified. + """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change the names. + snow_index.name = "test2" + native_index.name = "test2" + + # Compare the names. + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the query compiler the DataFrame is referring to, change the names. + snow_df.dropna(inplace=True) + native_df.dropna(inplace=True) + snow_index.name = "test3" + native_index.name = "test3" + + # Compare the names. Changing the index name should not change the DataFrame's index name. + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test2" + + +@sql_count_checker(query_count=1) +def test_index_names_multiple_renames(): + """ + Check that the index name of a DataFrame can be renamed any number of times. 
+ """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change and compare the names. + snow_index.name = "test2" + native_index.name = "test2" + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the names again and compare. + snow_index.name = "test3" + native_index.name = "test3" + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test3" diff --git a/tests/integ/modin/series/test_aggregate.py b/tests/integ/modin/series/test_aggregate.py index fa354fda1fc..c3e40828d94 100644 --- a/tests/integ/modin/series/test_aggregate.py +++ b/tests/integ/modin/series/test_aggregate.py @@ -1,6 +1,8 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import re + import modin.pandas as pd import numpy as np import pandas as native_pd @@ -17,6 +19,7 @@ MAP_DATA_AND_TYPE, MIXED_NUMERIC_STR_DATA_AND_TYPE, TIMESTAMP_DATA_AND_TYPE, + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_test_series, eval_snowpark_pandas_result, ) @@ -358,3 +361,67 @@ def test_2_tuple_named_agg_errors_for_series(native_series, agg_kwargs): expect_exception_type=SpecificationError, assert_exception_equal=True, ) + + +class TestTimedelta: + """Test aggregating a timedelta series.""" + + @pytest.mark.parametrize( + "func, union_count, is_scalar", + [ + pytest.param(*v, id=str(i)) + for i, v in enumerate( + [ + (lambda series: series.aggregate(["min"]), 0, False), + (lambda series: series.aggregate({"A": "max"}), 0, False), + # this works since even though we need to do concats, all the results are non-timdelta. 
+ (lambda df: df.aggregate(["all", "any", "count"]), 2, False), + # note following aggregation requires transpose + (lambda df: df.aggregate(max), 0, True), + (lambda df: df.min(), 0, True), + (lambda df: df.max(), 0, True), + (lambda df: df.count(), 0, True), + (lambda df: df.sum(), 0, True), + (lambda df: df.mean(), 0, True), + (lambda df: df.median(), 0, True), + (lambda df: df.std(), 0, True), + (lambda df: df.quantile(), 0, True), + (lambda df: df.quantile([0.01, 0.99]), 0, False), + ] + ) + ], + ) + def test_supported(self, func, union_count, timedelta_native_df, is_scalar): + with SqlCounter(query_count=1, union_count=union_count): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + func, + comparator=validate_scalar_result + if is_scalar + else assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + ) + + @sql_count_checker(query_count=0) + def test_var_invalid(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + lambda series: series.var(), + expect_exception=True, + expect_exception_type=TypeError, + assert_exception_equal=False, + expect_exception_match=re.escape( + "timedelta64 type does not support var operations" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", + ) + def test_unsupported_due_to_concat(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + lambda df: df.agg(["count", "max"]), + ) diff --git a/tests/integ/modin/series/test_argmax_argmin.py b/tests/integ/modin/series/test_argmax_argmin.py index 607b36a27f3..e212e3ba2dd 100644 --- a/tests/integ/modin/series/test_argmax_argmin.py +++ b/tests/integ/modin/series/test_argmax_argmin.py @@ -18,6 +18,11 @@ ([4, None, 1, 3, 4, 1], ["A", "B", "C", "D", "E", "F"]), ([4, None, 1, 3, 4, 1], [None, "B", "C", "D", "E", "F"]), ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + pytest.param( + [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)], + ["A", "B", "C", "D", "E"], + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["argmax", "argmin"]) diff --git a/tests/integ/modin/series/test_describe.py b/tests/integ/modin/series/test_describe.py index 9ecd2e33a3d..0f7bbda6c3a 100644 --- a/tests/integ/modin/series/test_describe.py +++ b/tests/integ/modin/series/test_describe.py @@ -11,6 +11,7 @@ from tests.integ.modin.sql_counter import sql_count_checker from tests.integ.modin.utils import ( assert_series_equal, + create_test_dfs, create_test_series, eval_snowpark_pandas_result, ) @@ -156,3 +157,18 @@ def test_describe_multiindex(data, index): eval_snowpark_pandas_result( *create_test_series(data, index=index), lambda ser: ser.describe() ) + + +@sql_count_checker(query_count=0) +@pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", +) +def test_timedelta(timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs( + timedelta_native_df, + ), + lambda df: df["A"].describe(), + ) diff --git a/tests/integ/modin/series/test_fillna.py b/tests/integ/modin/series/test_fillna.py index 9371cd0dcd1..80997070b92 100644 --- a/tests/integ/modin/series/test_fillna.py +++ b/tests/integ/modin/series/test_fillna.py @@ -3,6 +3,8 @@ # +import string + import modin.pandas as pd import numpy as np import pandas as native_pd @@ -201,3 +203,17 @@ def 
inplace_fillna(df): native_pd.DataFrame([[1, 2, 3], [4, None, 6]], columns=list("ABC")), inplace_fillna, ) + + +@pytest.mark.parametrize("index", [list(range(8)), list(string.ascii_lowercase[:8])]) +@sql_count_checker(query_count=1, join_count=4) +def test_inplace_fillna_from_series(index): + def inplace_fillna(series): + series.iloc[:4].fillna(14, inplace=True) + return series + + eval_snowpark_pandas_result( + pd.Series([np.nan, 1, 2, 3, 4, 5, 6, 7], index=index), + native_pd.Series([np.nan, 1, 2, 3, 4, 5, 6, 7], index=index), + inplace_fillna, + ) diff --git a/tests/integ/modin/series/test_first_last_valid_index.py b/tests/integ/modin/series/test_first_last_valid_index.py index 1e8d052e10f..1930bdf1088 100644 --- a/tests/integ/modin/series/test_first_last_valid_index.py +++ b/tests/integ/modin/series/test_first_last_valid_index.py @@ -22,6 +22,10 @@ native_pd.Series([5, 6, 7, 8], index=["i", "am", "iron", "man"]), native_pd.Series([None, None, 2], index=[None, 1, 2]), native_pd.Series([None, None, 2], index=[None, None, None]), + pytest.param( + native_pd.Series([None, None, pd.Timedelta(2)], index=[None, 1, 2]), + id="timedelta", + ), ], ) def test_first_and_last_valid_index_series(native_series): diff --git a/tests/integ/modin/series/test_idxmax_idxmin.py b/tests/integ/modin/series/test_idxmax_idxmin.py index ea536240a42..e8e66a30f61 100644 --- a/tests/integ/modin/series/test_idxmax_idxmin.py +++ b/tests/integ/modin/series/test_idxmax_idxmin.py @@ -17,6 +17,11 @@ ([1, None, 4, 3, 4], ["A", "B", "C", "D", "E"]), ([1, None, 4, 3, 4], [None, "B", "C", "D", "E"]), ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + pytest.param( + [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)], + ["A", "B", "C", "D", "E"], + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) diff --git a/tests/integ/modin/series/test_nunique.py b/tests/integ/modin/series/test_nunique.py index bb20e9e4a53..3856dbc516a 100644 --- a/tests/integ/modin/series/test_nunique.py +++ b/tests/integ/modin/series/test_nunique.py @@ -6,6 +6,7 @@ import numpy as np import pandas as native_pd import pytest +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker @@ -32,6 +33,20 @@ [True, None, False, True, None], [1.1, "a", None] * 4, [native_pd.to_datetime("2023-12-01"), native_pd.to_datetime("1999-09-09")] * 2, + param( + [ + native_pd.Timedelta(1), + native_pd.Timedelta(1), + native_pd.Timedelta(2), + None, + None, + ], + id="timedelta_with_nulls", + ), + param( + [native_pd.Timedelta(1), native_pd.Timedelta(1), native_pd.Timedelta(2)], + id="timedelta_without_nulls", + ), ], ) @pytest.mark.parametrize("dropna", [True, False]) diff --git a/tests/integ/modin/test_classes.py b/tests/integ/modin/test_classes.py index c92bb85c531..6e6c2eda8eb 100644 --- a/tests/integ/modin/test_classes.py +++ b/tests/integ/modin/test_classes.py @@ -34,14 +34,14 @@ def test_class_names_constructors(): expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) s = pd.Series(index=[1, 2, 3], data=[3, 2, 1]) expect_type_check( s, pd.Series, - "snowflake.snowpark.modin.pandas.series.Series", + "modin.pandas.series.Series", ) @@ -63,7 +63,7 @@ def test_op(): expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) @@ -77,7 +77,7 @@ def test_native_conversion(): 
expect_type_check( df, pd.DataFrame, - "snowflake.snowpark.modin.pandas.dataframe.DataFrame", + "modin.pandas.dataframe.DataFrame", ) # Snowpark pandas -> native pandas diff --git a/tests/integ/modin/test_merge_asof.py b/tests/integ/modin/test_merge_asof.py index 681d339da90..5aab91fc9cb 100644 --- a/tests/integ/modin/test_merge_asof.py +++ b/tests/integ/modin/test_merge_asof.py @@ -105,6 +105,7 @@ def left_right_timestamp_data(): pd.Timestamp("2016-05-25 13:30:00.072"), pd.Timestamp("2016-05-25 13:30:00.075"), ], + "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"], "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], } @@ -118,6 +119,7 @@ def left_right_timestamp_data(): pd.Timestamp("2016-05-25 13:30:00.048"), pd.Timestamp("2016-05-25 13:30:00.048"), ], + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], "price": [51.95, 51.95, 720.77, 720.92, 98.0], "quantity": [75, 155, 100, 100, 100], } @@ -229,14 +231,70 @@ def test_merge_asof_left_right_on( assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) +@allow_exact_matches +@direction +@sql_count_checker(query_count=1, join_count=1) +def test_merge_asof_left_right_index(allow_exact_matches, direction): + native_left = native_pd.DataFrame({"left_val": ["a", "b", "c"]}, index=[1, 5, 10]) + native_right = native_pd.DataFrame( + {"right_val": [1, 2, 3, 6, 7]}, index=[1, 2, 3, 6, 7] + ) + + snow_left = pd.DataFrame(native_left) + snow_right = pd.DataFrame(native_right) + + native_output = native_pd.merge_asof( + native_left, + native_right, + left_index=True, + right_index=True, + direction=direction, + allow_exact_matches=allow_exact_matches, + ) + snow_output = pd.merge_asof( + snow_left, + snow_right, + left_index=True, + right_index=True, + direction=direction, + allow_exact_matches=allow_exact_matches, + ) + assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) + + +@pytest.mark.parametrize("by", ["ticker", ["ticker"]]) +@sql_count_checker(query_count=1, join_count=1) +def test_merge_asof_by(left_right_timestamp_data, by): + left_native_df, right_native_df = left_right_timestamp_data + left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( + right_native_df + ) + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", by=by + ) + snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by=by) + assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) + + +@pytest.mark.parametrize( + "left_by, right_by", + [ + ("ticker", "ticker"), + (["ticker", "bid"], ["ticker", "price"]), + ], +) @sql_count_checker(query_count=1, join_count=1) -def test_merge_asof_timestamps(left_right_timestamp_data): +def test_merge_asof_left_right_by(left_right_timestamp_data, left_by, right_by): left_native_df, right_native_df = left_right_timestamp_data left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( right_native_df ) - native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time") - snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time") + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", left_by=left_by, right_by=right_by + ) + snow_output = pd.merge_asof( + left_snow_df, right_snow_df, on="time", left_by=left_by, right_by=right_by + ) assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) @@ -248,8 +306,10 @@ def test_merge_asof_date(left_right_timestamp_data): 
left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( right_native_df ) - native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time") - snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time") + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", by="ticker" + ) + snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by="ticker") assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) @@ -360,9 +420,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): with pytest.raises( NotImplementedError, match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'" ), ): pd.merge_asof( @@ -372,28 +430,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" - ), - ): - pd.merge_asof( - left_snow_df, right_snow_df, on="time", left_by="price", right_by="quantity" - ) - with pytest.raises( - NotImplementedError, - match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" - ), - ): - pd.merge_asof(left_snow_df, right_snow_df, left_index=True, right_index=True) - with pytest.raises( - NotImplementedError, - match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'suffixes', or 'tolerance'" ), ): pd.merge_asof( @@ -406,8 +443,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'suffixes', or 'tolerance'" ), ): pd.merge_asof( diff --git a/tests/integ/modin/test_telemetry.py b/tests/integ/modin/test_telemetry.py index ce9e1caf328..a36298af251 100644 --- a/tests/integ/modin/test_telemetry.py +++ b/tests/integ/modin/test_telemetry.py @@ -110,7 +110,7 @@ def test_snowpark_pandas_telemetry_method_decorator(test_table_name): df1_expected_api_calls = [ {"name": "TestClass.test_func"}, - {"name": "DataFrame.DataFrame.dropna", "argument": ["inplace"]}, + {"name": "DataFrame.dropna", "argument": ["inplace"]}, ] assert df1._query_compiler.snowpark_pandas_api_calls == df1_expected_api_calls @@ -121,7 +121,7 @@ def test_snowpark_pandas_telemetry_method_decorator(test_table_name): assert df1._query_compiler.snowpark_pandas_api_calls == df1_expected_api_calls df2_expected_api_calls = df1_expected_api_calls + [ { - "name": "DataFrame.DataFrame.dropna", + "name": "DataFrame.dropna", }, ] assert df2._query_compiler.snowpark_pandas_api_calls == df2_expected_api_calls @@ -336,10 +336,7 @@ def test_telemetry_with_update_inplace(): df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) df.insert(1, "newcol", [99, 99, 90]) assert len(df._query_compiler.snowpark_pandas_api_calls) == 1 - assert ( - df._query_compiler.snowpark_pandas_api_calls[0]["name"] - == "DataFrame.DataFrame.insert" - ) + assert 
df._query_compiler.snowpark_pandas_api_calls[0]["name"] == "DataFrame.insert" @sql_count_checker(query_count=1) @@ -403,8 +400,8 @@ def test_telemetry_getitem_setitem(): df["a"] = 0 df["b"] = 0 assert df._query_compiler.snowpark_pandas_api_calls == [ - {"name": "DataFrame.DataFrame.__setitem__"}, - {"name": "DataFrame.DataFrame.__setitem__"}, + {"name": "DataFrame.__setitem__"}, + {"name": "DataFrame.__setitem__"}, ] # Clear connector telemetry client buffer to avoid flush triggered by the next API call, ensuring log extraction. s._query_compiler._modin_frame.ordered_dataframe.session._conn._telemetry_client.telemetry.send_batch() @@ -422,13 +419,17 @@ def test_telemetry_getitem_setitem(): @pytest.mark.parametrize( - "name, method, expected_query_count", + "name, expected_func_name, method, expected_query_count", [ - ["__repr__", lambda df: df.__repr__(), 1], - ["__iter__", lambda df: df.__iter__(), 0], + # __repr__ is an extension method, so the class name is shown only once. + ["__repr__", "DataFrame.__repr__", lambda df: df.__repr__(), 1], + # __iter__ was defined on the DataFrame class, so it is shown twice. + ["__iter__", "DataFrame.DataFrame.__iter__", lambda df: df.__iter__(), 0], ], ) -def test_telemetry_private_method(name, method, expected_query_count): +def test_telemetry_private_method( + name, expected_func_name, method, expected_query_count +): df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) # Clear connector telemetry client buffer to avoid flush triggered by the next API call, ensuring log extraction. df._query_compiler._modin_frame.ordered_dataframe.session._conn._telemetry_client.telemetry.send_batch() @@ -439,10 +440,10 @@ def test_telemetry_private_method(name, method, expected_query_count): # the telemetry log from the connector to validate data = _extract_snowpark_pandas_telemetry_log_data( - expected_func_name=f"DataFrame.DataFrame.{name}", + expected_func_name=expected_func_name, session=df._query_compiler._modin_frame.ordered_dataframe.session, ) - assert data["api_calls"] == [{"name": f"DataFrame.DataFrame.{name}"}] + assert data["api_calls"] == [{"name": expected_func_name}] @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/types/test_timedelta.py b/tests/integ/modin/types/test_timedelta.py index 4c72df42bba..d28362374ce 100644 --- a/tests/integ/modin/types/test_timedelta.py +++ b/tests/integ/modin/types/test_timedelta.py @@ -2,10 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# import datetime +import warnings import modin.pandas as pd import pandas as native_pd import pytest +from pandas.errors import SettingWithCopyWarning from tests.integ.modin.sql_counter import sql_count_checker from tests.integ.modin.utils import ( @@ -107,3 +109,10 @@ def test_timedelta_not_supported(): match="SnowflakeQueryCompiler::groupby_groups is not yet implemented for Timedelta Type", ): df.groupby("a").groups() + + +@sql_count_checker(query_count=1) +def test_aggregation_does_not_print_internal_warning_SNOW_1664064(): + with warnings.catch_warnings(): + warnings.simplefilter(category=SettingWithCopyWarning, action="error") + pd.Series(pd.Timedelta(1)).max() diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index e42a504a976..bdd780ea69e 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -9,9 +9,13 @@ import pytest from snowflake.snowpark._internal.analyzer import analyzer -from snowflake.snowpark._internal.compiler import large_query_breakdown from snowflake.snowpark.functions import col, lit, sum_distinct, when_matched from snowflake.snowpark.row import Row +from snowflake.snowpark.session import ( + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + Session, +) from tests.utils import Utils pytestmark = [ @@ -22,9 +26,6 @@ ) ] -DEFAULT_LOWER_BOUND = large_query_breakdown.COMPLEXITY_SCORE_LOWER_BOUND -DEFAULT_UPPER_BOUND = large_query_breakdown.COMPLEXITY_SCORE_UPPER_BOUND - @pytest.fixture(autouse=True) def large_query_df(session): @@ -50,20 +51,24 @@ def setup(session): is_query_compilation_stage_enabled = session._query_compilation_stage_enabled session._query_compilation_stage_enabled = True session._large_query_breakdown_enabled = True + set_bounds(session, 300, 600) yield session._query_compilation_stage_enabled = is_query_compilation_stage_enabled session._cte_optimization_enabled = cte_optimization_enabled session._large_query_breakdown_enabled = large_query_breakdown_enabled - reset_bounds() + reset_bounds(session) -def set_bounds(lower_bound: int, upper_bound: int): - large_query_breakdown.COMPLEXITY_SCORE_LOWER_BOUND = lower_bound - large_query_breakdown.COMPLEXITY_SCORE_UPPER_BOUND = upper_bound +def set_bounds(session: Session, lower_bound: int, upper_bound: int): + session._large_query_breakdown_complexity_bounds = (lower_bound, upper_bound) -def reset_bounds(): - set_bounds(DEFAULT_LOWER_BOUND, DEFAULT_UPPER_BOUND) +def reset_bounds(session: Session): + set_bounds( + session, + DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND, + DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, + ) def check_result_with_and_without_breakdown(session, df): @@ -82,8 +87,6 @@ def check_result_with_and_without_breakdown(session, df): def test_no_valid_nodes_found(session, large_query_df, caplog): """Test large query breakdown works with default bounds""" - set_bounds(300, 600) - base_df = session.sql("select 1 as A, 2 as B") df1 = base_df.with_column("A", col("A") + lit(1)) df2 = base_df.with_column("B", col("B") + lit(1)) @@ -104,7 +107,6 @@ def test_no_valid_nodes_found(session, large_query_df, caplog): def test_large_query_breakdown_with_cte_optimization(session): """Test large query breakdown works with cte optimized plan""" - set_bounds(300, 600) session._cte_optimization_enabled = True df0 = session.sql("select 2 as b, 32 as c") df1 = session.sql("select 1 as a, 2 as b").filter(col("a") == 1) @@ -131,7 +133,6 @@ def 
test_large_query_breakdown_with_cte_optimization(session): def test_save_as_table(session, large_query_df): - set_bounds(300, 600) table_name = Utils.random_table_name() with session.query_history() as history: large_query_df.write.save_as_table(table_name, mode="overwrite") @@ -146,7 +147,6 @@ def test_save_as_table(session, large_query_df): def test_update_delete_merge(session, large_query_df): - set_bounds(300, 600) session._large_query_breakdown_enabled = True table_name = Utils.random_table_name() df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"]) @@ -186,7 +186,6 @@ def test_update_delete_merge(session, large_query_df): def test_copy_into_location(session, large_query_df): - set_bounds(300, 600) remote_file_path = f"{session.get_session_stage()}/df.parquet" with session.query_history() as history: large_query_df.write.copy_into_location( @@ -204,7 +203,6 @@ def test_copy_into_location(session, large_query_df): def test_pivot_unpivot(session): - set_bounds(300, 600) session.sql( """create or replace temp table monthly_sales(A int, B int, month text) as select * from values @@ -243,7 +241,6 @@ def test_pivot_unpivot(session): def test_sort(session): - set_bounds(300, 600) base_df = session.sql("select 1 as A, 2 as B") df1 = base_df.with_column("A", col("A") + lit(1)) df2 = base_df.with_column("B", col("B") + lit(1)) @@ -276,7 +273,6 @@ def test_sort(session): def test_multiple_query_plan(session, large_query_df): - set_bounds(300, 600) original_threshold = analyzer.ARRAY_BIND_THRESHOLD try: analyzer.ARRAY_BIND_THRESHOLD = 2 @@ -314,7 +310,6 @@ def test_multiple_query_plan(session, large_query_df): def test_optimization_skipped_with_transaction(session, large_query_df, caplog): """Test large query breakdown is skipped when transaction is enabled""" - set_bounds(300, 600) session.sql("begin").collect() assert Utils.is_active_transaction(session) with caplog.at_level(logging.DEBUG): @@ -330,7 +325,6 @@ def test_optimization_skipped_with_transaction(session, large_query_df, caplog): def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog): """Test large query breakdown is skipped plan is a view or dynamic table""" - set_bounds(300, 600) source_table = Utils.random_table_name() table_name = Utils.random_table_name() view_name = Utils.random_view_name() @@ -360,7 +354,6 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog): def test_async_job_with_large_query_breakdown(session, large_query_df): """Test large query breakdown gives same result for async and non-async jobs""" - set_bounds(300, 600) job = large_query_df.collect(block=False) result = job.result() assert result == large_query_df.collect() @@ -376,8 +369,6 @@ def test_async_job_with_large_query_breakdown(session, large_query_df): def test_add_parent_plan_uuid_to_statement_params(session, large_query_df): - set_bounds(300, 600) - with patch.object( session._conn, "run_query", wraps=session._conn.run_query ) as patched_run_query: @@ -400,7 +391,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): """Test complexity bounds affect number of partitions. Also test that when partitions are added, drop table queries are added. 
""" - set_bounds(300, 600) + set_bounds(session, 300, 600) assert len(large_query_df.queries["queries"]) == 2 assert len(large_query_df.queries["post_actions"]) == 1 assert large_query_df.queries["queries"][0].startswith( @@ -410,7 +401,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): "DROP TABLE If EXISTS" ) - set_bounds(300, 412) + set_bounds(session, 300, 412) assert len(large_query_df.queries["queries"]) == 3 assert len(large_query_df.queries["post_actions"]) == 2 assert large_query_df.queries["queries"][0].startswith( @@ -426,11 +417,11 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): "DROP TABLE If EXISTS" ) - set_bounds(0, 300) + set_bounds(session, 0, 300) assert len(large_query_df.queries["queries"]) == 1 assert len(large_query_df.queries["post_actions"]) == 0 - reset_bounds() + reset_bounds(session) assert len(large_query_df.queries["queries"]) == 1 assert len(large_query_df.queries["post_actions"]) == 0 diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 0e8bb0d902d..81b852c46c1 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -98,6 +98,24 @@ def test_range_statement(session: Session): ) +def test_literal_complexity_for_snowflake_values(session: Session): + from snowflake.snowpark._internal.analyzer import analyzer + + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + assert_df_subtree_query_complexity( + df1, {PlanNodeCategory.COLUMN: 4, PlanNodeCategory.LITERAL: 4} + ) + + try: + original_threshold = analyzer.ARRAY_BIND_THRESHOLD + analyzer.ARRAY_BIND_THRESHOLD = 2 + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + # SELECT "A", "B" from (SELECT * FROM TEMP_TABLE) + assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN: 3}) + finally: + analyzer.ARRAY_BIND_THRESHOLD = original_threshold + + def test_generator_table_function(session: Session): df1 = session.generator( seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150 diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index df0afc1099b..21e77883338 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -5,6 +5,7 @@ import os from functools import partial +from unittest.mock import patch import pytest @@ -719,6 +720,31 @@ def test_eliminate_numeric_sql_value_cast_optimization_enabled_on_session( new_session.eliminate_numeric_sql_value_cast_enabled = None +def test_large_query_breakdown_complexity_bounds(session): + original_bounds = session.large_query_breakdown_complexity_bounds + try: + with pytest.raises(ValueError, match="Expecting a tuple of two integers"): + session.large_query_breakdown_complexity_bounds = (1, 2, 3) + + with pytest.raises( + ValueError, match="Expecting a tuple of lower and upper bound" + ): + session.large_query_breakdown_complexity_bounds = (3, 2) + + with patch.object( + session._conn._telemetry_client, + "send_large_query_breakdown_update_complexity_bounds", + ) as patch_send: + session.large_query_breakdown_complexity_bounds = (1, 2) + assert session.large_query_breakdown_complexity_bounds == (1, 2) + assert patch_send.call_count == 1 + assert patch_send.call_args[0][0] == session.session_id + assert patch_send.call_args[0][1] == 1 + assert patch_send.call_args[0][2] == 2 + finally: + session.large_query_breakdown_complexity_bounds = original_bounds + + @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def 
test_create_session_from_default_config_file(monkeypatch, db_parameters): import tomlkit diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index 7aaa5c9e5dd..39749de76f6 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -1223,3 +1223,51 @@ def send_telemetry(): data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) assert data == expected_data assert type_ == "snowpark_compilation_stage_statistics" + + +def test_temp_table_cleanup(session): + client = session._conn._telemetry_client + + def send_telemetry(): + client.send_temp_table_cleanup_telemetry( + session.session_id, + temp_table_cleaner_enabled=True, + num_temp_tables_cleaned=2, + num_temp_tables_created=5, + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "temp_table_cleaner_enabled": True, + "num_temp_tables_cleaned": 2, + "num_temp_tables_created": 5, + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_temp_table_cleanup" + + +def test_temp_table_cleanup_exception(session): + client = session._conn._telemetry_client + + def send_telemetry(): + client.send_temp_table_cleanup_abnormal_exception_telemetry( + session.session_id, + table_name="table_name_placeholder", + exception_message="exception_message_placeholder", + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "temp_table_cleanup_abnormal_exception_table_name": "table_name_placeholder", + "temp_table_cleanup_abnormal_exception_message": "exception_message_placeholder", + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_temp_table_cleanup_abnormal_exception" diff --git a/tests/integ/test_temp_table_cleanup.py b/tests/integ/test_temp_table_cleanup.py index 4ac87661484..cdd97d49937 100644 --- a/tests/integ/test_temp_table_cleanup.py +++ b/tests/integ/test_temp_table_cleanup.py @@ -12,6 +12,7 @@ from snowflake.snowpark._internal.utils import ( TempObjectType, random_name_for_temp_object, + warning_dict, ) from snowflake.snowpark.functions import col from tests.utils import IS_IN_STORED_PROC @@ -25,40 +26,61 @@ WAIT_TIME = 1 +@pytest.fixture(autouse=True) +def setup(session): + auto_clean_up_temp_table_enabled = session.auto_clean_up_temp_table_enabled + session.auto_clean_up_temp_table_enabled = True + yield + session.auto_clean_up_temp_table_enabled = auto_clean_up_temp_table_enabled + + def test_basic(session): + session._temp_table_auto_cleaner.ref_count_map.clear() df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() table_name = df1.table_name table_ids = table_name.split(".") df1.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = df1.select("*").filter(col("a") == 1) df2.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df3 = df1.union_all(df2) df3.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created 
== 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 del df2 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 del df3 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids) - assert table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 def test_function(session): + session._temp_table_auto_cleaner.ref_count_map.clear() table_name = None def f(session: Session) -> None: @@ -68,13 +90,16 @@ def f(session: Session) -> None: nonlocal table_name table_name = df.table_name assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() f(session) gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_name.split(".")) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 @pytest.mark.parametrize( @@ -86,33 +111,42 @@ def f(session: Session) -> None: ], ) def test_copy(session, copy_function): + session._temp_table_auto_cleaner.ref_count_map.clear() df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() table_name = df1.table_name table_ids = table_name.split(".") df1.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = copy_function(df1).select("*").filter(col("a") == 1) df2.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 2 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df2 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids) - assert table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create 
session in SP") def test_reference_count_map_multiple_sessions(db_parameters, session): + session._temp_table_auto_cleaner.ref_count_map.clear() new_session = Session.builder.configs(db_parameters).create() + new_session.auto_clean_up_temp_table_enabled = True try: df1 = session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] @@ -120,43 +154,59 @@ def test_reference_count_map_multiple_sessions(db_parameters, session): table_name1 = df1.table_name table_ids1 = table_name1.split(".") assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 1 - assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 + assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 0 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = new_session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] ).cache_result() table_name2 = df2.table_name table_ids2 = table_name2.split(".") - assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids1) assert new_session._table_exists(table_ids2) assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 - assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 + assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - new_session._temp_table_auto_cleaner.start() del df2 gc.collect() time.sleep(WAIT_TIME) assert not new_session._table_exists(table_ids2) - assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 finally: new_session.close() def test_save_as_table_no_drop(session): - session._temp_table_auto_cleaner.start() + session._temp_table_auto_cleaner.ref_count_map.clear() def f(session: Session, temp_table_name: str) -> None: session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] ).write.save_as_table(temp_table_name, table_type="temp") - assert session._temp_table_auto_cleaner.ref_count_map[temp_table_name] == 0 + assert temp_table_name not in session._temp_table_auto_cleaner.ref_count_map + assert 
session._temp_table_auto_cleaner.num_temp_tables_created == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) f(session, temp_table_name) @@ -165,34 +215,25 @@ def f(session: Session, temp_table_name: str) -> None: assert session._table_exists([temp_table_name]) -def test_start_stop(session): - session._temp_table_auto_cleaner.stop() - - df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() - table_name = df1.table_name +def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog): + warning_dict.clear() + with caplog.at_level(logging.WARNING): + session.auto_clean_up_temp_table_enabled = False + assert session.auto_clean_up_temp_table_enabled is False + assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() + table_name = df.table_name table_ids = table_name.split(".") - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 - del df1 + del df gc.collect() - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 - assert not session._temp_table_auto_cleaner.queue.empty() - assert session._table_exists(table_ids) - - session._temp_table_auto_cleaner.start() time.sleep(WAIT_TIME) - assert session._temp_table_auto_cleaner.queue.empty() - assert not session._table_exists(table_ids) - - -def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog): - with caplog.at_level(logging.WARNING): - session.auto_clean_up_temp_table_enabled = True + assert session._table_exists(table_ids) + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 + session.auto_clean_up_temp_table_enabled = True assert session.auto_clean_up_temp_table_enabled is True - assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text - assert session._temp_table_auto_cleaner.is_alive() - session.auto_clean_up_temp_table_enabled = False - assert session.auto_clean_up_temp_table_enabled is False - assert not session._temp_table_auto_cleaner.is_alive() + with pytest.raises( ValueError, match="value for auto_clean_up_temp_table_enabled must be True or False!", diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py index 7c5e3a40bb0..d94c80b8d67 100644 --- a/tests/unit/modin/modin/test_envvars.py +++ b/tests/unit/modin/modin/test_envvars.py @@ -166,6 +166,7 @@ def test_overrides(self): # Test for pandas doc when function is not defined on module. assert pandas.read_table.__doc__ in pd.read_table.__doc__ + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_not_redefining_classes_modin_issue_7138(self): original_dataframe_class = pd.DataFrame _init_doc_module() diff --git a/tests/unit/modin/test_aggregation_utils.py b/tests/unit/modin/test_aggregation_utils.py index 5434387ba71..6c9edfd024f 100644 --- a/tests/unit/modin/test_aggregation_utils.py +++ b/tests/unit/modin/test_aggregation_utils.py @@ -2,12 +2,20 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
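# [Editor's sketch -- not part of the diff.] The cleanup tests above now assert on
# counters rather than on keys disappearing: ref_count_map keeps the table name with
# a count of 0, while num_temp_tables_created / num_temp_tables_cleaned track the
# table's lifecycle. A condensed version of that flow, assuming an active Snowpark
# `session` with auto_clean_up_temp_table_enabled set to True:
import gc
import time


def sketch_cache_result_cleanup(session):
    df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result()
    table_name = df.table_name
    df.collect()  # materialize once, as the tests do
    cleaner = session._temp_table_auto_cleaner
    assert cleaner.ref_count_map[table_name] == 1
    assert cleaner.num_temp_tables_created == 1

    del df  # drop the last reference to the cached result
    gc.collect()
    time.sleep(1)  # give the drop a moment, mirroring WAIT_TIME in the tests
    assert cleaner.ref_count_map[table_name] == 0
    assert cleaner.num_temp_tables_cleaned == 1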
 #
+from types import MappingProxyType
+from unittest import mock
+
 import numpy as np
 import pytest
 
+import snowflake.snowpark.modin.plugin._internal.aggregation_utils as aggregation_utils
+from snowflake.snowpark.functions import greatest, sum as sum_
 from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
+    SnowflakeAggFunc,
+    _is_supported_snowflake_agg_func,
+    _SnowparkPandasAggregation,
     check_is_aggregation_supported_in_snowflake,
-    is_supported_snowflake_agg_func,
+    get_snowflake_agg_func,
 )
@@ -53,8 +61,8 @@
         ("quantile", {}, 1, False),
     ],
 )
-def test_is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None:
-    assert is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid
+def test__is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None:
+    assert _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid
 
 
 @pytest.mark.parametrize(
@@ -103,3 +111,40 @@ def test_check_aggregation_snowflake_execution_capability_by_args(
         agg_func=agg_func, agg_kwargs=agg_kwargs, axis=0
     )
     assert can_be_distributed == expected_result
+
+
+@pytest.mark.parametrize(
+    "agg_func, agg_kwargs, axis, expected",
+    [
+        (np.sum, {}, 0, SnowflakeAggFunc(sum_, True)),
+        ("max", {"skipna": False}, 1, SnowflakeAggFunc(greatest, True)),
+        ("test", {}, 0, None),
+    ],
+)
+def test_get_snowflake_agg_func(agg_func, agg_kwargs, axis, expected):
+    result = get_snowflake_agg_func(agg_func, agg_kwargs, axis)
+    if expected is None:
+        assert result is None
+    else:
+        assert result == expected
+
+
+def test_get_snowflake_agg_func_with_no_implementation_on_axis_0():
+    """Test get_snowflake_agg_func for a function that we support on axis=1 but not on axis=0."""
+    # We have to patch the internal dictionary
+    # _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION here because there is
+    # no real function that we support on axis=1 but not on axis=0.
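# [Editor's aside -- not part of the diff.] For contrast, the unpatched mapping used by
# the parametrized test above resolves supported aggregations directly:
#     get_snowflake_agg_func(np.sum, {}, 0)                -> SnowflakeAggFunc(sum_, True)
#     get_snowflake_agg_func("max", {"skipna": False}, 1)  -> SnowflakeAggFunc(greatest, True)
#     get_snowflake_agg_func("test", {}, 0)                -> None
# The patch that follows swaps in a mapping whose "max" entry registers only axis=1
# implementations, so the axis=0 lookup finds nothing and returns None.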
+ with mock.patch.object( + aggregation_utils, + "_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION", + MappingProxyType( + { + "max": _SnowparkPandasAggregation( + preserves_snowpark_pandas_types=True, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=greatest, + ) + } + ), + ): + assert get_snowflake_agg_func(agg_func="max", agg_kwargs={}, axis=0) is None diff --git a/tests/unit/test_expression_dependent_columns.py b/tests/unit/test_expression_dependent_columns.py index c31e5cc6290..c9b8a1ce38d 100644 --- a/tests/unit/test_expression_dependent_columns.py +++ b/tests/unit/test_expression_dependent_columns.py @@ -87,30 +87,37 @@ def test_expression(): a = Expression() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] b = Expression(child=UnresolvedAttribute("a")) assert b.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert b.dependent_column_names_with_duplication() == [] # root class Expression always returns empty dependency def test_literal(): a = Literal(5) assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] def test_attribute(): a = Attribute("A", IntegerType()) assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] def test_unresolved_attribute(): a = UnresolvedAttribute("A") assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] b = UnresolvedAttribute("a > 1", is_sql_text=True) assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] c = UnresolvedAttribute("$1 > 1", is_sql_text=True) assert c.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert c.dependent_column_names_with_duplication() == ["$"] def test_case_when(): @@ -118,46 +125,85 @@ def test_case_when(): b = Column("b") z = when(a > b, col("c")).when(a < b, col("d")).else_(col("e")) assert z._expression.dependent_column_names() == {'"A"', '"B"', '"C"', '"D"', '"E"'} + # verify column '"A"', '"B"' occurred twice in the dependency columns + assert z._expression.dependent_column_names_with_duplication() == [ + '"A"', + '"B"', + '"C"', + '"A"', + '"B"', + '"D"', + '"E"', + ] def test_collate(): a = Collate(UnresolvedAttribute("a"), "spec") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_function_expression(): a = FunctionExpression("test_func", [UnresolvedAttribute(x) for x in "abcd"], False) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # expressions with duplicated dependent column + b = FunctionExpression( + "test_func", [UnresolvedAttribute(x) for x in "abcdad"], False + ) + assert b.dependent_column_names() == set("abcd") + assert b.dependent_column_names_with_duplication() == list("abcdad") def test_in_expression(): a = InExpression(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") def test_like(): a = Like(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + # with duplication + b = Like(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert 
b.dependent_column_names_with_duplication() == ["a", "a"] def test_list_agg(): a = ListAgg(UnresolvedAttribute("a"), ",", True) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_multiple_expression(): a = MultipleExpression([UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # with duplication + a = MultipleExpression([UnresolvedAttribute(x) for x in "abcdbea"]) + assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("abcdbea") def test_reg_exp(): a = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + b = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] def test_scalar_subquery(): a = ScalarSubquery(None) assert a.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert a.dependent_column_names_with_duplication() == list(COLUMN_DEPENDENCY_DOLLAR) def test_snowflake_udf(): @@ -165,21 +211,42 @@ def test_snowflake_udf(): "udf_name", [UnresolvedAttribute(x) for x in "abcd"], IntegerType() ) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # with duplication + b = SnowflakeUDF( + "udf_name", [UnresolvedAttribute(x) for x in "abcdfc"], IntegerType() + ) + assert b.dependent_column_names() == set("abcdf") + assert b.dependent_column_names_with_duplication() == list("abcdfc") def test_star(): a = Star([Attribute(x, IntegerType()) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + b = Star([]) + assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] def test_subfield_string(): a = SubfieldString(UnresolvedAttribute("a"), "field") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_within_group(): a = WithinGroup(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") + + b = WithinGroup( + UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcdea"] + ) + assert b.dependent_column_names() == set("abcde") + assert b.dependent_column_names_with_duplication() == list("eabcdea") @pytest.mark.parametrize( @@ -189,16 +256,19 @@ def test_within_group(): def test_unary_expression(expression_class): a = expression_class(child=UnresolvedAttribute("a")) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_alias(): a = Alias(child=Add(UnresolvedAttribute("a"), UnresolvedAttribute("b")), name="c") assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_cast(): a = Cast(UnresolvedAttribute("a"), IntegerType()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] @pytest.mark.parametrize( @@ -234,6 +304,19 @@ def test_binary_expression(expression_class): assert b.dependent_column_names() == {"B"} assert 
binary_expression.dependent_column_names() == {"A", "B"} + assert a.dependent_column_names_with_duplication() == ["A"] + assert b.dependent_column_names_with_duplication() == ["B"] + assert binary_expression.dependent_column_names_with_duplication() == ["A", "B"] + + # hierarchical expressions with duplication + hierarchical_binary_expression = expression_class(expression_class(a, b), b) + assert hierarchical_binary_expression.dependent_column_names() == {"A", "B"} + assert hierarchical_binary_expression.dependent_column_names_with_duplication() == [ + "A", + "B", + "B", + ] + @pytest.mark.parametrize( "expression_class", @@ -253,6 +336,18 @@ def test_grouping_set(expression_class): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] + + # with duplication + b = expression_class( + [ + UnresolvedAttribute("a"), + UnresolvedAttribute("a"), + UnresolvedAttribute("c"), + ] + ) + assert b.dependent_column_names() == {"a", "c"} + assert b.dependent_column_names_with_duplication() == ["a", "a", "c"] def test_grouping_sets_expression(): @@ -263,11 +358,13 @@ def test_grouping_sets_expression(): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] def test_sort_order(): a = SortOrder(UnresolvedAttribute("a"), Ascending()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_specified_window_frame(): @@ -275,12 +372,21 @@ def test_specified_window_frame(): RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("b") ) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + # with duplication + b = SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("a") + ) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] @pytest.mark.parametrize("expression_class", [RankRelatedFunctionExpression, Lag, Lead]) def test_rank_related_function_expression(expression_class): a = expression_class(UnresolvedAttribute("a"), 1, UnresolvedAttribute("b"), False) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_window_spec_definition(): @@ -295,6 +401,7 @@ def test_window_spec_definition(): ), ) assert a.dependent_column_names() == set("abcdef") + assert a.dependent_column_names_with_duplication() == list("abcdef") def test_window_expression(): @@ -310,6 +417,23 @@ def test_window_expression(): ) a = WindowExpression(UnresolvedAttribute("x"), window_spec_definition) assert a.dependent_column_names() == set("abcdefx") + assert a.dependent_column_names_with_duplication() == list("xabcdef") + + +def test_window_expression_with_duplication_columns(): + window_spec_definition = WindowSpecDefinition( + [UnresolvedAttribute("a"), UnresolvedAttribute("b")], + [ + SortOrder(UnresolvedAttribute("c"), Ascending()), + SortOrder(UnresolvedAttribute("a"), Ascending()), + ], + SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("e"), UnresolvedAttribute("f") + ), + ) + a = WindowExpression(UnresolvedAttribute("e"), window_spec_definition) + assert a.dependent_column_names() == set("abcef") + assert a.dependent_column_names_with_duplication() == list("eabcaef") @pytest.mark.parametrize( @@ -325,3 +449,4 @@ def test_window_expression(): def 
test_other_window_expressions(expression_class): a = expression_class() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 262c9e82c44..370ee455d62 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -112,6 +112,7 @@ def test_used_scoped_temp_object(): def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() + fake_connection._telemetry_client = mock.Mock() fake_connection.is_closed = MagicMock(return_value=False) exception_msg = "Mock exception for session.cancel_all" fake_connection.run_query = MagicMock(side_effect=Exception(exception_msg))
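# [Editor's sketch -- not part of the diff.] Summary of the duplicate-preserving
# dependency API exercised in tests/unit/test_expression_dependent_columns.py above:
# dependent_column_names() still returns a de-duplicated set, while the new
# dependent_column_names_with_duplication() returns a list that repeats a column for
# every occurrence in the expression tree. Import paths assume the analyzer module
# layout referenced elsewhere in this diff.
from snowflake.snowpark._internal.analyzer.expression import (
    FunctionExpression,
    UnresolvedAttribute,
)

expr = FunctionExpression("concat", [UnresolvedAttribute(c) for c in ("a", "b", "a")], False)
assert expr.dependent_column_names() == {"a", "b"}
assert expr.dependent_column_names_with_duplication() == ["a", "b", "a"]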