diff --git a/CHANGELOG.md b/CHANGELOG.md
index e0589d4a358..90f2003251d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,14 @@
 #### New Features
 
 - Added support for `TimedeltaIndex.mean` method.
+- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`.
+- Added support for `by`, `left_by`, `right_by`, `left_index`, and `right_index` for `pd.merge_asof`.
+#### Bug Fixes
+
+- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updated the `Series`/`DataFrame`'s index name after an inplace update was applied to the original `Series`/`DataFrame`.
+- Suppressed an unhelpful `SettingWithCopyWarning` that sometimes appeared when printing `Timedelta` columns.
+- Fixed the `inplace` argument for `Series` objects derived from other `Series` objects.
 
 ## 1.22.1 (2024-09-11)
 This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content.
@@ -124,6 +131,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det
 - Added support for `Series.dt.total_seconds` method.
 - Added support for `DataFrame.apply(axis=0)`.
 - Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`.
+- Added support for `DatetimeIndex.tz_convert` and `DatetimeIndex.tz_localize`.
 
 #### Improvements
 
diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst
index 68b1935da96..3afe671aee7 100644
--- a/docs/source/modin/supported/datetime_index_supported.rst
+++ b/docs/source/modin/supported/datetime_index_supported.rst
@@ -82,9 +82,9 @@ Methods
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``snap``                    | N                               |                                  |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``tz_convert``              | N                               |                                  |                                                    |
+| ``tz_convert``              | Y                               |                                  |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``tz_localize``             | N                               |                                  |                                                    |
+| ``tz_localize``             | P                               | ``ambiguous``, ``nonexistent``   |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``round``                   | P                               | ``ambiguous``, ``nonexistent``   |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
diff --git a/docs/source/modin/supported/general_supported.rst b/docs/source/modin/supported/general_supported.rst
index 797ef3bbd59..b3a71f023a9 100644
--- a/docs/source/modin/supported/general_supported.rst
+++ b/docs/source/modin/supported/general_supported.rst
@@ -38,9 +38,7 @@ Data manipulations
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``merge``                   | P                               | ``validate``                     | ``N`` if param ``validate`` is given               |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``merge_asof``              | P                               | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``.       |
-|                             |                                 | , ``left_index``, ``right_index``|                                                    |
-|                             |                                 | , ``suffixes``, ``tolerance``    |                                                    |
+| ``merge_asof``              | P                               | ``suffixes``, ``tolerance``      | ``N`` if param ``direction`` is ``nearest``        |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``merge_ordered``           | N                               |                                  |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
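The table change above narrows `merge_asof`'s unsupported parameters to `suffixes` and `tolerance`. A minimal sketch of the newly supported call shape follows; the data and column names are illustrative (not taken from this patch), and an active Snowpark pandas session is assumed:

import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # attaches the Snowpark pandas backend

quotes = pd.DataFrame(
    {
        "time": pd.to_datetime(["2024-09-12 09:00:00", "2024-09-12 09:00:03"]),
        "ticker": ["SNOW", "SNOW"],
        "bid": [110.0, 110.5],
    }
)
trades = pd.DataFrame(
    {
        "time": pd.to_datetime(["2024-09-12 09:00:02"]),
        "ticker": ["SNOW"],
        "price": [110.2],
    }
)

# `by` groups rows before the backward as-of match; `left_by`/`right_by` and
# `left_index`/`right_index` are likewise supported per the table above, while
# direction="nearest" remains unsupported.
matched = pd.merge_asof(trades, quotes, on="time", by="ticker")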
diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
index 76e91b7da92..d8622299ea9 100644
--- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py
+++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -956,10 +956,7 @@ def do_resolve_with_resolved_children(
                 schema_query = schema_query_for_values_statement(logical_plan.output)
 
             if logical_plan.data:
-                if (
-                    len(logical_plan.output) * len(logical_plan.data)
-                    < ARRAY_BIND_THRESHOLD
-                ):
+                if not logical_plan.is_large_local_data:
                     return self.plan_builder.query(
                         values_statement(logical_plan.output, logical_plan.data),
                         logical_plan,
diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py
index 3ed969caada..22591f55e47 100644
--- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py
@@ -2,11 +2,12 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #
 
-from typing import AbstractSet, Optional
+from typing import AbstractSet, List, Optional
 
 from snowflake.snowpark._internal.analyzer.expression import (
     Expression,
     derive_dependent_columns,
+    derive_dependent_columns_with_duplication,
 )
 from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
     PlanNodeCategory,
@@ -29,6 +30,9 @@ def __str__(self):
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.left, self.right)
 
+    def dependent_column_names_with_duplication(self) -> List[str]:
+        return derive_dependent_columns_with_duplication(self.left, self.right)
+
     @property
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.LOW_IMPACT
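`BinaryExpression.dependent_column_names_with_duplication` above delegates to a new helper added in `expression.py` below. A hedged sketch of the contrast between the set-based and duplication-preserving forms (constructor shapes assumed from this diff; not part of the patch):

from snowflake.snowpark._internal.analyzer.binary_expression import Add, Multiply
from snowflake.snowpark._internal.analyzer.expression import (
    Attribute,
    derive_dependent_columns,
    derive_dependent_columns_with_duplication,
)
from snowflake.snowpark.types import IntegerType

col1 = Attribute("COL1", IntegerType())
col2 = Attribute("COL2", IntegerType())
expr = Add(col1, Multiply(col1, col2))  # COL1 + COL1 * COL2

# The set form deduplicates; the list form keeps one entry per occurrence,
# which lets downstream analysis weigh repeated column references.
assert derive_dependent_columns(expr) == {"COL1", "COL2"}
assert derive_dependent_columns_with_duplication(expr) == ["COL1", "COL1", "COL2"]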
diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py
index a2d21db4eb2..a7cb5fd97a9 100644
--- a/src/snowflake/snowpark/_internal/analyzer/expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -35,6 +35,13 @@
 def derive_dependent_columns(
     *expressions: "Optional[Expression]",
 ) -> Optional[AbstractSet[str]]:
+    """
+    Given a set of expressions, derive the set of columns that the expressions depend on.
+
+    Note that the returned dependent columns form a set without duplicates. For example, given
+    the expression concat(col1, upper(col1), upper(col2)), the result is {col1, col2} even though
+    col1 occurs twice in the expression.
+    """
     result = set()
     for exp in expressions:
         if exp is not None:
@@ -48,6 +55,23 @@
     return result
 
 
+def derive_dependent_columns_with_duplication(
+    *expressions: "Optional[Expression]",
+) -> List[str]:
+    """
+    Given a set of expressions, derive the list of columns that the expressions depend on.
+
+    Note that the returned list keeps duplicates when a column occurs more than once in the
+    given expressions. For example, concat(col1, upper(col1), upper(col2)) yields
+    [col1, col1, col2], where col1 occurs twice.
+    """
+    result = []
+    for exp in expressions:
+        if exp is not None:
+            result.extend(exp.dependent_column_names_with_duplication())
+    return result
+
+
 class Expression:
     """Consider removing attributes, and adding properties and methods.
     A subclass of Expression may have no child, one child, or multiple children.
@@ -68,6 +92,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         # TODO: consider adding it to __init__ or use cached_property.
         return COLUMN_DEPENDENCY_EMPTY
 
+    def dependent_column_names_with_duplication(self) -> List[str]:
+        return []
+
     @property
     def pretty_name(self) -> str:
         """Returns a user-facing string representation of this expression's name.
@@ -143,6 +170,9 @@ def __init__(self, plan: "SnowflakePlan") -> None:
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return COLUMN_DEPENDENCY_DOLLAR
 
+    def dependent_column_names_with_duplication(self) -> List[str]:
+        return list(COLUMN_DEPENDENCY_DOLLAR)
+
     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return self.plan.cumulative_node_complexity
@@ -156,6 +186,9 @@ def __init__(self, expressions: List[Expression]) -> None:
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(*self.expressions)
 
+    def dependent_column_names_with_duplication(self) -> List[str]:
+        return derive_dependent_columns_with_duplication(*self.expressions)
+
     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         return sum_node_complexities(
@@ -172,6 +205,9 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None:
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.columns, *self.values)
 
+    def dependent_column_names_with_duplication(self) -> List[str]:
+        return derive_dependent_columns_with_duplication(self.columns, *self.values)
+
     @property
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.IN
@@ -212,6 +248,9 @@ def __str__(self):
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return {self.name}
 
+    def dependent_column_names_with_duplication(self) -> List[str]:
+        return [self.name]
+
     @property
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.COLUMN
@@ -235,6 +274,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         else COLUMN_DEPENDENCY_ALL
     )
 
+    def dependent_column_names_with_duplication(self) -> List[str]:
+        return (
+            derive_dependent_columns_with_duplication(*self.expressions)
+            if self.expressions
+            else []  # we currently do not handle * dependency
+        )
+
     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1}
@@ -278,6 +324,14 @@ def __hash__(self):
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return self._dependent_column_names
 
+    def dependent_column_names_with_duplication(self) -> List[str]:
+        return (
+            []
+            if (self._dependent_column_names == COLUMN_DEPENDENCY_ALL)
+            or (self._dependent_column_names is None)
+            else list(self._dependent_column_names)
+        )
+
     @property
     def plan_node_category(self) -> PlanNodeCategory:
         return PlanNodeCategory.COLUMN
@@ -371,6 +425,9 @@ def __init__(self, expr: Expression, pattern: Expression) -> None:
     def dependent_column_names(self) -> Optional[AbstractSet[str]]:
         return derive_dependent_columns(self.expr, 
self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern @@ -400,6 +457,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern @@ -423,6 +483,9 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec @@ -444,6 +507,9 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -466,6 +532,9 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -510,6 +579,9 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -525,6 +597,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols) + @property def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) @@ -549,13 +624,21 @@ def __init__( self.branches = branches self.else_value = else_value - def dependent_column_names(self) -> Optional[AbstractSet[str]]: + @property + def _child_expressions(self) -> List[Expression]: exps = [] for exp_tuple in self.branches: exps.extend(exp_tuple) if self.else_value is not None: exps.append(self.else_value) - return derive_dependent_columns(*exps) + + return exps + + def dependent_column_names(self) -> Optional[AbstractSet[str]]: + return derive_dependent_columns(*self._child_expressions) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self._child_expressions) @property def plan_node_category(self) -> PlanNodeCategory: @@ -602,6 +685,9 @@ def __init__( def 
dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -617,6 +703,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.col) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -636,6 +725,9 @@ def __init__(self, exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.exprs) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 84cd63fd87d..012940471d0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.group_by_exprs) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) + def dependent_column_names_with_duplication(self) -> List[str]: + flattened_args = [exp for sublist in self.args for exp in sublist] + return derive_dependent_columns_with_duplication(*flattened_args) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index e3e032cd94b..aa8730dcf7f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -144,10 +144,27 @@ def __init__( self.data = data self.schema_query = schema_query + @property + def is_large_local_data(self) -> bool: + from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD + + return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self.is_large_local_data: + # When the number of literals exceeds the threshold, we generate 3 queries: + # 1. create table query + # 2. 
insert into table query + # 3. select * from table query + # We only consider the complexity from the final select * query since other queries + # are built based on it. + return { + PlanNodeCategory.COLUMN: 1, + } + + # If we stay under the threshold, we generate a single query: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) - # TODO: use ARRAY_BIND_THRESHOLD return { PlanNodeCategory.COLUMN: len(self.output), PlanNodeCategory.LITERAL: len(self.data) * len(self.output), diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 1d06f7290a0..82451245e4c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional, Type +from typing import AbstractSet, List, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) @@ -55,3 +56,6 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e5886e11069..1ae08e8fde2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,12 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from typing import AbstractSet, Dict, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -36,6 +37,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 69db3f265ce..4381c4a2e22 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -71,6 +72,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.lower, self.upper) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + *self.partition_spec, *self.order_spec, self.frame_spec + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # partition_spec order_by_spec frame_spec @@ -138,6 +147,11 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + self.window_function, self.window_spec + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW @@ -171,6 +185,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.default) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # for func_name diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py index 836628345aa..8d16383a4ce 100644 --- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py +++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py @@ -58,11 +58,6 @@ ) from snowflake.snowpark.session import Session -# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT -# in Snowflake. This is the limit where we start seeing compilation errors. 
-COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000 -COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000 - _logger = logging.getLogger(__name__) @@ -123,6 +118,12 @@ def __init__( self._query_generator = query_generator self.logical_plans = logical_plans self._parent_map = defaultdict(set) + self.complexity_score_lower_bound = ( + session.large_query_breakdown_complexity_bounds[0] + ) + self.complexity_score_upper_bound = ( + session.large_query_breakdown_complexity_bounds[1] + ) def apply(self) -> List[LogicalPlan]: if is_active_transaction(self.session): @@ -183,13 +184,13 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]: complexity_score = get_complexity_score(root.cumulative_node_complexity) _logger.debug(f"Complexity score for root {type(root)} is: {complexity_score}") - if complexity_score <= COMPLEXITY_SCORE_UPPER_BOUND: + if complexity_score <= self.complexity_score_upper_bound: # Skip optimization if the complexity score is within the upper bound. return [root] plans = [] # TODO: SNOW-1617634 Have a one pass algorithm to find the valid node for partitioning - while complexity_score > COMPLEXITY_SCORE_UPPER_BOUND: + while complexity_score > self.complexity_score_upper_bound: child = self._find_node_to_breakdown(root) if child is None: _logger.debug( @@ -277,7 +278,9 @@ def _is_node_valid_to_breakdown(self, node: LogicalPlan) -> Tuple[bool, int]: """ score = get_complexity_score(node.cumulative_node_complexity) valid_node = ( - COMPLEXITY_SCORE_LOWER_BOUND < score < COMPLEXITY_SCORE_UPPER_BOUND + self.complexity_score_lower_bound + < score + < self.complexity_score_upper_bound ) and self._is_node_pipeline_breaker(node) if valid_node: _logger.debug( diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 211b66820ec..3e6dba71be4 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -16,8 +16,6 @@ ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.compiler.large_query_breakdown import ( - COMPLEXITY_SCORE_LOWER_BOUND, - COMPLEXITY_SCORE_UPPER_BOUND, LargeQueryBreakdown, ) from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import ( @@ -128,10 +126,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: summary_value = { TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, - CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: ( - COMPLEXITY_SCORE_LOWER_BOUND, - COMPLEXITY_SCORE_UPPER_BOUND, - ), + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds, CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index 223b6a1326f..be61a1ac924 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -11,6 +11,9 @@ class CompilationStageTelemetryField(Enum): "snowpark_large_query_breakdown_optimization_skipped" ) 
TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics" + TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS = ( + "snowpark_large_query_breakdown_update_complexity_bounds" + ) # keys KEY_REASON = "reason" diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 8b9ef2acccb..025eb57c540 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -79,6 +79,20 @@ class TelemetryField(Enum): QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" QUERY_PLAN_COMPLEXITY = "query_plan_complexity" + # temp table cleanup + TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup" + NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned" + NUM_TEMP_TABLES_CREATED = "num_temp_tables_created" + TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled" + TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = ( + "snowpark_temp_table_cleanup_abnormal_exception" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = ( + "temp_table_cleanup_abnormal_exception_table_name" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = ( + "temp_table_cleanup_abnormal_exception_message" + ) # These DataFrame APIs call other DataFrame APIs @@ -374,7 +388,7 @@ def send_sql_simplifier_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, - TelemetryField.SQL_SIMPLIFIER_ENABLED.value: True, + TelemetryField.SQL_SIMPLIFIER_ENABLED.value: sql_simplifier_enabled, }, } self.send(message) @@ -428,7 +442,7 @@ def send_large_query_breakdown_telemetry( ), TelemetryField.KEY_DATA.value: { TelemetryField.SESSION_ID.value: session_id, - TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: True, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: value, }, } self.send(message) @@ -464,3 +478,60 @@ def send_large_query_optimization_skipped_telemetry( }, } self.send(message) + + def send_temp_table_cleanup_telemetry( + self, + session_id: str, + temp_table_cleaner_enabled: bool, + num_temp_tables_cleaned: int, + num_temp_tables_created: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANER_ENABLED.value: temp_table_cleaner_enabled, + TelemetryField.NUM_TEMP_TABLES_CLEANED.value: num_temp_tables_cleaned, + TelemetryField.NUM_TEMP_TABLES_CREATED.value: num_temp_tables_created, + }, + } + self.send(message) + + def send_temp_table_cleanup_abnormal_exception_telemetry( + self, + session_id: str, + table_name: str, + exception_message: str, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME.value: table_name, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE.value: exception_message, + }, + } + self.send(message) + + def send_large_query_breakdown_update_complexity_bounds( + self, session_id: int, lower_bound: int, upper_bound: int + ): + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.KEY_DATA.value: { + 
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: (
+                        lower_bound,
+                        upper_bound,
+                    ),
+                },
+            },
+        }
+        self.send(message)
diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
index b9055c6fc58..4fa17498d34 100644
--- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
+++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
@@ -4,9 +4,7 @@
 import logging
 import weakref
 from collections import defaultdict
-from queue import Empty, Queue
-from threading import Event, Thread
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict
 
 from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable
 
@@ -33,12 +31,6 @@ def __init__(self, session: "Session") -> None:
         # to its reference count for later temp table management
         # this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
         self.ref_count_map: Dict[str, int] = defaultdict(int)
-        # unused temp table will be put into the queue for cleanup
-        self.queue: Queue = Queue()
-        # thread for removing temp tables (running DROP TABLE sql)
-        self.cleanup_thread: Optional[Thread] = None
-        # An event managing a flag that indicates whether the cleaner is started
-        self.stop_event = Event()
 
     def add(self, table: SnowflakeTable) -> None:
         self.ref_count_map[table.name] += 1
@@ -46,61 +38,60 @@ def add(self, table: SnowflakeTable) -> None:
         # the finalizer will be triggered when it is garbage collected
         # and this table will be dropped finally
         _ = weakref.finalize(table, self._delete_ref_count, table.name)
 
-    def _delete_ref_count(self, name: str) -> None:
+    def _delete_ref_count(self, name: str) -> None:  # pragma: no cover
         """
-        Decrements the reference count of a temporary table, and if the count reaches zero, puts
-        this table in the queue for cleanup.
+        Decrements the reference count of a temporary table and, if the count reaches zero,
+        drops the table immediately (when auto cleanup is enabled).
         """
         self.ref_count_map[name] -= 1
         if self.ref_count_map[name] == 0:
-            self.ref_count_map.pop(name)
-            # clean up
-            self.queue.put(name)
+            if self.session.auto_clean_up_temp_table_enabled:
+                self.drop_table(name)
         elif self.ref_count_map[name] < 0:
             logging.debug(
                 f"Unexpected reference count {self.ref_count_map[name]} for table {name}"
             )
 
-    def process_cleanup(self) -> None:
-        while not self.stop_event.is_set():
-            try:
-                # it's non-blocking after timeout and become interruptable with stop_event
-                # it will raise an `Empty` exception if queue is empty after timeout,
-                # then we catch this exception and avoid breaking loop
-                table_name = self.queue.get(timeout=1)
-                self.drop_table(table_name)
-            except Empty:
-                continue
-
-    def drop_table(self, name: str) -> None:
+    def drop_table(self, name: str) -> None:  # pragma: no cover
         common_log_text = f"temp table {name} in session {self.session.session_id}"
-        logging.debug(f"Cleanup Thread: Ready to drop {common_log_text}")
+        logging.debug(f"Ready to drop {common_log_text}")
+        query_id = None
         try:
-            # TODO SNOW-1556553: Remove this workaround once multi-threading of Snowpark session is supported
-            with self.session._conn._conn.cursor() as cursor:
-                cursor.execute(
-                    f"drop table if exists {name} /* internal query to drop unused temp table */",
-                    _statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name},
+            async_job = self.session.sql(
+                f"drop table if exists {name} /* internal query to drop unused temp table */",
+            )._internal_collect_with_tag_no_telemetry(
+                block=False, statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name}
+            )
+            query_id = async_job.query_id
+            logging.debug(f"Dropping {common_log_text} with query id {query_id}")
+        except Exception as ex:  # pragma: no cover
+            warning_message = f"Failed to drop {common_log_text}, exception: {ex}"
+            logging.warning(warning_message)
+            if query_id is None:
+                # If no query_id is available, it means the query hasn't been accepted by gs,
+                # and it won't occur in our job_etl_view; send a separate telemetry for recording.
+                self.session._conn._telemetry_client.send_temp_table_cleanup_abnormal_exception_telemetry(
+                    self.session.session_id,
+                    name,
+                    str(ex),
                 )
-            logging.debug(f"Cleanup Thread: Successfully dropped {common_log_text}")
-        except Exception as ex:
-            logging.warning(
-                f"Cleanup Thread: Failed to drop {common_log_text}, exception: {ex}"
-            )  # pragma: no cover
-
-    def is_alive(self) -> bool:
-        return self.cleanup_thread is not None and self.cleanup_thread.is_alive()
-
-    def start(self) -> None:
-        self.stop_event.clear()
-        if not self.is_alive():
-            self.cleanup_thread = Thread(target=self.process_cleanup)
-            self.cleanup_thread.start()
 
     def stop(self) -> None:
         """
-        The cleaner will stop immediately and leave unfinished temp tables in the queue.
+        Stops the cleaner (no-op) and sends the telemetry.
""" - self.stop_event.set() - if self.is_alive(): - self.cleanup_thread.join() + self.session._conn._telemetry_client.send_temp_table_cleanup_telemetry( + self.session.session_id, + temp_table_cleaner_enabled=self.session.auto_clean_up_temp_table_enabled, + num_temp_tables_cleaned=self.num_temp_tables_cleaned, + num_temp_tables_created=self.num_temp_tables_created, + ) + + @property + def num_temp_tables_created(self) -> int: + return len(self.ref_count_map) + + @property + def num_temp_tables_cleaned(self) -> int: + # TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled + return sum(v == 0 for v in self.ref_count_map.values()) diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index 6960d0eb629..8f9834630b7 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -88,13 +88,13 @@ # TODO: SNOW-851745 make sure add all Snowpark pandas API general functions from modin.pandas import plotting # type: ignore[import] +from modin.pandas.dataframe import DataFrame from modin.pandas.series import Series from snowflake.snowpark.modin.pandas.api.extensions import ( register_dataframe_accessor, register_series_accessor, ) -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.general import ( bdate_range, concat, @@ -185,10 +185,8 @@ modin.pandas.base._ATTRS_NO_LOOKUP.update(_ATTRS_NO_LOOKUP) -# For any method defined on Series/DF, add telemetry to it if it: -# 1. Is defined directly on an upstream class -# 2. The method name does not start with an _, or is in TELEMETRY_PRIVATE_METHODS - +# For any method defined on Series/DF, add telemetry to it if the method name does not start with an +# _, or the method is in TELEMETRY_PRIVATE_METHODS. This includes methods defined as an extension/override. for attr_name in dir(Series): # Since Series is defined in upstream Modin, all of its members were either defined upstream # or overridden by extension. @@ -197,11 +195,9 @@ try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name)) ) - -# TODO: SNOW-1063346 -# Since we still use the vendored version of DataFrame and the overrides for the top-level -# namespace haven't been performed yet, we need to set properties on the vendored version for attr_name in dir(DataFrame): + # Since DataFrame is defined in upstream Modin, all of its members were either defined upstream + # or overridden by extension. if not attr_name.startswith("_") or attr_name in TELEMETRY_PRIVATE_METHODS: register_dataframe_accessor(attr_name)( try_add_telemetry_to_attribute(attr_name, getattr(DataFrame, attr_name)) diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py index 6a34f50e42a..47d44835fe4 100644 --- a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py @@ -19,9 +19,12 @@ # existing code originally distributed by the Modin project, under the Apache License, # Version 2.0. 
-from modin.pandas.api.extensions import register_series_accessor +from modin.pandas.api.extensions import ( + register_dataframe_accessor, + register_series_accessor, +) -from .extensions import register_dataframe_accessor, register_pd_accessor +from .extensions import register_pd_accessor __all__ = [ "register_dataframe_accessor", diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py index 45896292e74..05424c92072 100644 --- a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py @@ -86,49 +86,6 @@ def decorator(new_attr: Any): return decorator -def register_dataframe_accessor(name: str): - """ - Registers a dataframe attribute with the name provided. - This is a decorator that assigns a new attribute to DataFrame. It can be used - with the following syntax: - ``` - @register_dataframe_accessor("new_method") - def my_new_dataframe_method(*args, **kwargs): - # logic goes here - return - ``` - The new attribute can then be accessed with the name provided: - ``` - df.new_method(*my_args, **my_kwargs) - ``` - - If you want a property accessor, you must annotate with @property - after the call to this function: - ``` - @register_dataframe_accessor("new_prop") - @property - def my_new_dataframe_property(*args, **kwargs): - return _prop - ``` - - Parameters - ---------- - name : str - The name of the attribute to assign to DataFrame. - Returns - ------- - decorator - Returns the decorator function. - """ - import snowflake.snowpark.modin.pandas as pd - - return _set_attribute_on_obj( - name, - pd.dataframe._DATAFRAME_EXTENSIONS_, - pd.dataframe.DataFrame, - ) - - def register_pd_accessor(name: str): """ Registers a pd namespace attribute with the name provided. diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py deleted file mode 100644 index 83893e83e9c..00000000000 --- a/src/snowflake/snowpark/modin/pandas/dataframe.py +++ /dev/null @@ -1,3511 +0,0 @@ -# -# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. -# - -# Licensed to Modin Development Team under one or more contributor license agreements. -# See the NOTICE file distributed with this work for additional information regarding -# copyright ownership. The Modin Development Team licenses this file to you under the -# Apache License, Version 2.0 (the "License"); you may not use this file except in -# compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under -# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific language -# governing permissions and limitations under the License. - -# Code in this file may constitute partial or total reimplementation, or modification of -# existing code originally distributed by the Modin project, under the Apache License, -# Version 2.0. 
- -"""Module houses ``DataFrame`` class, that is distributed version of ``pandas.DataFrame``.""" - -from __future__ import annotations - -import collections -import datetime -import functools -import itertools -import re -import sys -import warnings -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from logging import getLogger -from typing import IO, Any, Callable, Literal - -import numpy as np -import pandas -from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor -from modin.pandas.base import BasePandasDataset - -# from . import _update_engine -from modin.pandas.iterator import PartitionIterator -from modin.pandas.series import Series -from pandas._libs.lib import NoDefault, no_default -from pandas._typing import ( - AggFuncType, - AnyArrayLike, - Axes, - Axis, - CompressionOptions, - FilePath, - FillnaOptions, - IgnoreRaise, - IndexLabel, - Level, - PythonFuncType, - Renamer, - Scalar, - StorageOptions, - Suffixes, - WriteBuffer, -) -from pandas.core.common import apply_if_callable, is_bool_indexer -from pandas.core.dtypes.common import ( - infer_dtype_from_object, - is_bool_dtype, - is_dict_like, - is_list_like, - is_numeric_dtype, -) -from pandas.core.dtypes.inference import is_hashable, is_integer -from pandas.core.indexes.frozen import FrozenList -from pandas.io.formats.printing import pprint_thing -from pandas.util._validators import validate_bool_kwarg - -from snowflake.snowpark.modin import pandas as pd -from snowflake.snowpark.modin.pandas.groupby import ( - DataFrameGroupBy, - validate_groupby_args, -) -from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( - SnowparkPandasRowPartitionIterator, -) -from snowflake.snowpark.modin.pandas.utils import ( - create_empty_native_pandas_frame, - from_non_pandas, - from_pandas, - is_scalar, - raise_if_native_pandas_objects, - replace_external_data_keys_with_empty_pandas_series, - replace_external_data_keys_with_query_compiler, -) -from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated -from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike -from snowflake.snowpark.modin.plugin.utils.error_message import ( - ErrorMessage, - dataframe_not_implemented, -) -from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP -from snowflake.snowpark.modin.plugin.utils.warning_message import ( - SET_DATAFRAME_ATTRIBUTE_WARNING, - WarningMessage, -) -from snowflake.snowpark.modin.utils import _inherit_docstrings, hashable, to_pandas -from snowflake.snowpark.udf import UserDefinedFunction - -logger = getLogger(__name__) - -DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = ( - "Currently do not support Series or list-like keys with range-like values" -) - -DF_SETITEM_SLICE_AS_SCALAR_VALUE = ( - "Currently do not support assigning a slice value as if it's a scalar value" -) - -DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = ( - "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark " - "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which " - "can work on the entire DataFrame in one shot." 
-) - -# Dictionary of extensions assigned to this class -_DATAFRAME_EXTENSIONS_ = {} - - -@_inherit_docstrings( - pandas.DataFrame, - excluded=[ - pandas.DataFrame.flags, - pandas.DataFrame.cov, - pandas.DataFrame.merge, - pandas.DataFrame.reindex, - pandas.DataFrame.to_parquet, - pandas.DataFrame.fillna, - ], - apilink="pandas.DataFrame", -) -class DataFrame(BasePandasDataset): - _pandas_class = pandas.DataFrame - - def __init__( - self, - data=None, - index=None, - columns=None, - dtype=None, - copy=None, - query_compiler=None, - ) -> None: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # Siblings are other dataframes that share the same query compiler. We - # use this list to update inplace when there is a shallow copy. - from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native - - self._siblings = [] - - # Engine.subscribe(_update_engine) - if isinstance(data, (DataFrame, Series)): - self._query_compiler = data._query_compiler.copy() - if index is not None and any(i not in data.index for i in index): - ErrorMessage.not_implemented( - "Passing non-existant columns or index values to constructor not" - + " yet implemented." - ) # pragma: no cover - if isinstance(data, Series): - # We set the column name if it is not in the provided Series - if data.name is None: - self.columns = [0] if columns is None else columns - # If the columns provided are not in the named Series, pandas clears - # the DataFrame and sets columns to the columns provided. - elif columns is not None and data.name not in columns: - self._query_compiler = from_pandas( - self.__constructor__(columns=columns) - )._query_compiler - if index is not None: - self._query_compiler = data.loc[index]._query_compiler - elif columns is None and index is None: - data._add_sibling(self) - else: - if columns is not None and any(i not in data.columns for i in columns): - ErrorMessage.not_implemented( - "Passing non-existant columns or index values to constructor not" - + " yet implemented." 
- ) # pragma: no cover - if index is None: - index = slice(None) - if columns is None: - columns = slice(None) - self._query_compiler = data.loc[index, columns]._query_compiler - - # Check type of data and use appropriate constructor - elif query_compiler is None: - distributed_frame = from_non_pandas(data, index, columns, dtype) - if distributed_frame is not None: - self._query_compiler = distributed_frame._query_compiler - return - - if isinstance(data, pandas.Index): - pass - elif is_list_like(data) and not is_dict_like(data): - old_dtype = getattr(data, "dtype", None) - values = [ - obj._to_pandas() if isinstance(obj, Series) else obj for obj in data - ] - if isinstance(data, np.ndarray): - data = np.array(values, dtype=old_dtype) - else: - try: - data = type(data)(values, dtype=old_dtype) - except TypeError: - data = values - elif is_dict_like(data) and not isinstance( - data, (pandas.Series, Series, pandas.DataFrame, DataFrame) - ): - if columns is not None: - data = {key: value for key, value in data.items() if key in columns} - - if len(data) and all(isinstance(v, Series) for v in data.values()): - from .general import concat - - new_qc = concat( - data.values(), axis=1, keys=data.keys() - )._query_compiler - - if dtype is not None: - new_qc = new_qc.astype({col: dtype for col in new_qc.columns}) - if index is not None: - new_qc = new_qc.reindex( - axis=0, labels=try_convert_index_to_native(index) - ) - if columns is not None: - new_qc = new_qc.reindex( - axis=1, labels=try_convert_index_to_native(columns) - ) - - self._query_compiler = new_qc - return - - data = { - k: v._to_pandas() if isinstance(v, Series) else v - for k, v in data.items() - } - pandas_df = pandas.DataFrame( - data=try_convert_index_to_native(data), - index=try_convert_index_to_native(index), - columns=try_convert_index_to_native(columns), - dtype=dtype, - copy=copy, - ) - self._query_compiler = from_pandas(pandas_df)._query_compiler - else: - self._query_compiler = query_compiler - - def __repr__(self): - """ - Return a string representation for a particular ``DataFrame``. - - Returns - ------- - str - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - num_rows = pandas.get_option("display.max_rows") or len(self) - # see _repr_html_ for comment, allow here also all column behavior - num_cols = pandas.get_option("display.max_columns") or len(self.columns) - - ( - row_count, - col_count, - repr_df, - ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x") - result = repr(repr_df) - - # if truncated, add shape information - if is_repr_truncated(row_count, col_count, num_rows, num_cols): - # The split here is so that we don't repr pandas row lengths. - return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format( - row_count, col_count - ) - else: - return result - - def _repr_html_(self): # pragma: no cover - """ - Return a html representation for a particular ``DataFrame``. - - Returns - ------- - str - - Notes - ----- - Supports pandas `display.max_rows` and `display.max_columns` options. - """ - num_rows = pandas.get_option("display.max_rows") or 60 - # Modin uses here 20 as default, but this does not coincide well with pandas option. Therefore allow - # here value=0 which means display all columns. 
-        num_cols = pandas.get_option("display.max_columns")
-
-        (
-            row_count,
-            col_count,
-            repr_df,
-        ) = self._query_compiler.build_repr_df(num_rows, num_cols)
-        result = repr_df._repr_html_()
-
-        if is_repr_truncated(row_count, col_count, num_rows, num_cols):
-            # We split so that we insert our correct dataframe dimensions.
-            return (
-                result.split("<div>")[0]
-                + f"<p>{row_count} rows × {col_count} columns</p>\n"
-            )
-        else:
-            return result
-
-    def _get_columns(self) -> pandas.Index:
-        """
-        Get the columns for this Snowpark pandas ``DataFrame``.
-
-        Returns
-        -------
-        Index
-            The all columns.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self._query_compiler.columns
-
-    def _set_columns(self, new_columns: Axes) -> None:
-        """
-        Set the columns for this Snowpark pandas ``DataFrame``.
-
-        Parameters
-        ----------
-        new_columns :
-            The new columns to set.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        self._update_inplace(
-            new_query_compiler=self._query_compiler.set_columns(new_columns)
-        )
-
-    columns = property(_get_columns, _set_columns)
-
-    @property
-    def ndim(self) -> int:
-        return 2
-
-    def drop_duplicates(
-        self, subset=None, keep="first", inplace=False, ignore_index=False
-    ):  # noqa: PR01, RT01, D200
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        """
-        Return ``DataFrame`` with duplicate rows removed.
-        """
-        return super().drop_duplicates(
-            subset=subset, keep=keep, inplace=inplace, ignore_index=ignore_index
-        )
-
-    def dropna(
-        self,
-        *,
-        axis: Axis = 0,
-        how: str | NoDefault = no_default,
-        thresh: int | NoDefault = no_default,
-        subset: IndexLabel = None,
-        inplace: bool = False,
-    ):  # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return super()._dropna(
-            axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace
-        )
-
-    @property
-    def dtypes(self):  # noqa: RT01, D200
-        """
-        Return the dtypes in the ``DataFrame``.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self._query_compiler.dtypes
-
-    def duplicated(
-        self, subset: Hashable | Sequence[Hashable] = None, keep: DropKeep = "first"
-    ):
-        """
-        Return boolean ``Series`` denoting duplicate rows.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        df = self[subset] if subset is not None else self
-        new_qc = df._query_compiler.duplicated(keep=keep)
-        duplicates = self._reduce_dimension(new_qc)
-        # remove Series name which was assigned automatically by .apply in QC
-        # this is pandas behavior, i.e., if duplicated result is a series, no name is returned
-        duplicates.name = None
-        return duplicates
-
-    @property
-    def empty(self) -> bool:
-        """
-        Indicate whether ``DataFrame`` is empty.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return len(self.columns) == 0 or len(self) == 0
-
-    @property
-    def axes(self):
-        """
-        Return a list representing the axes of the ``DataFrame``.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return [self.index, self.columns]
-
-    @property
-    def shape(self) -> tuple[int, int]:
-        """
-        Return a tuple representing the dimensionality of the ``DataFrame``.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return len(self), len(self.columns)
-
-    def add_prefix(self, prefix):
-        """
-        Prefix labels with string `prefix`.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        # pandas converts non-string prefix values into str and adds it to the column labels.
-        return self.__constructor__(
-            query_compiler=self._query_compiler.add_substring(
-                str(prefix), substring_type="prefix", axis=1
-            )
-        )
-
-    def add_suffix(self, suffix):
-        """
-        Suffix labels with string `suffix`.
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas converts non-string suffix values into str and appends it to the column labels. - return self.__constructor__( - query_compiler=self._query_compiler.add_substring( - str(suffix), substring_type="suffix", axis=1 - ) - ) - - @dataframe_not_implemented() - def map( - self, func, na_action: str | None = None, **kwargs - ) -> DataFrame: # pragma: no cover - if not callable(func): - raise ValueError(f"'{type(func)}' object is not callable") - return self.__constructor__( - query_compiler=self._query_compiler.map(func, na_action=na_action, **kwargs) - ) - - def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not callable(func): - raise TypeError(f"{func} is not callable") - return self.__constructor__( - query_compiler=self._query_compiler.applymap( - func, na_action=na_action, **kwargs - ) - ) - - def aggregate( - self, func: AggFuncType = None, axis: Axis = 0, *args: Any, **kwargs: Any - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().aggregate(func, axis, *args, **kwargs) - - agg = aggregate - - def apply( - self, - func: AggFuncType | UserDefinedFunction, - axis: Axis = 0, - raw: bool = False, - result_type: Literal["expand", "reduce", "broadcast"] | None = None, - args=(), - **kwargs, - ): - """ - Apply a function along an axis of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) - query_compiler = self._query_compiler.apply( - func, - axis, - raw=raw, - result_type=result_type, - args=args, - **kwargs, - ) - if not isinstance(query_compiler, type(self._query_compiler)): - # A scalar was returned - return query_compiler - - # If True, it is an unamed series. - # Theoretically, if df.apply returns a Series, it will only be an unnamed series - # because the function is supposed to be series -> scalar. - if query_compiler._modin_frame.is_unnamed_series(): - return Series(query_compiler=query_compiler) - else: - return self.__constructor__(query_compiler=query_compiler) - - def groupby( - self, - by=None, - axis: Axis | NoDefault = no_default, - level: IndexLabel | None = None, - as_index: bool = True, - sort: bool = True, - group_keys: bool = True, - observed: bool | NoDefault = no_default, - dropna: bool = True, - ): - """ - Group ``DataFrame`` using a mapper or by a ``Series`` of columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if axis is not no_default: - axis = self._get_axis_number(axis) - if axis == 1: - warnings.warn( - "DataFrame.groupby with axis=1 is deprecated. Do " - + "`frame.T.groupby(...)` without axis instead.", - FutureWarning, - stacklevel=1, - ) - else: - warnings.warn( - "The 'axis' keyword in DataFrame.groupby is deprecated and " - + "will be removed in a future version.", - FutureWarning, - stacklevel=1, - ) - else: - axis = 0 - - validate_groupby_args(by, level, observed) - - axis = self._get_axis_number(axis) - - if axis != 0 and as_index is False: - raise ValueError("as_index=False only valid for axis=0") - - idx_name = None - - if ( - not isinstance(by, Series) - and is_list_like(by) - and len(by) == 1 - # if by is a list-like of (None,), we have to keep it as a list because - # None may be referencing a column or index level whose label is - # `None`, and by=None wold mean that there is no `by` param. 
-            and by[0] is not None
-        ):
-            by = by[0]
-
-        if hashable(by) and (
-            not callable(by) and not isinstance(by, (pandas.Grouper, FrozenList))
-        ):
-            idx_name = by
-        elif isinstance(by, Series):
-            idx_name = by.name
-            if by._parent is self:
-                # if the SnowSeries comes from the current dataframe,
-                # convert it to labels directly for easy processing
-                by = by.name
-        elif is_list_like(by):
-            if axis == 0 and all(
-                (
-                    (hashable(o) and (o in self))
-                    or isinstance(o, Series)
-                    or (is_list_like(o) and len(o) == len(self.shape[axis]))
-                )
-                for o in by
-            ):
-                # Split 'by's into those that belong to self (internal_by) and
-                # those that don't (external_by). For SnowSeries that belong to
-                # the current DataFrame, we convert them to labels for easy
-                # processing.
-                internal_by, external_by = [], []
-
-                for current_by in by:
-                    if hashable(current_by):
-                        internal_by.append(current_by)
-                    elif isinstance(current_by, Series):
-                        if current_by._parent is self:
-                            internal_by.append(current_by.name)
-                        else:
-                            external_by.append(current_by)  # pragma: no cover
-                    else:
-                        external_by.append(current_by)
-
-                by = internal_by + external_by
-
-        return DataFrameGroupBy(
-            self,
-            by,
-            axis,
-            level,
-            as_index,
-            sort,
-            group_keys,
-            idx_name,
-            observed=observed,
-            dropna=dropna,
-        )
-
-    def keys(self):  # noqa: RT01, D200
-        """
-        Get columns of the ``DataFrame``.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self.columns
-
-    def transpose(self, copy=False, *args):  # noqa: PR01, RT01, D200
-        """
-        Transpose index and columns.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        if copy:
-            WarningMessage.ignored_argument(
-                operation="transpose",
-                argument="copy",
-                message="Transpose ignores the copy argument in Snowpark pandas API",
-            )
-
-        if args:
-            WarningMessage.ignored_argument(
-                operation="transpose",
-                argument="args",
-                message="Transpose ignores args in Snowpark pandas API",
-            )
-
-        return self.__constructor__(query_compiler=self._query_compiler.transpose())
-
-    T = property(transpose)
-
-    def add(
-        self, other, axis="columns", level=None, fill_value=None
-    ):  # noqa: PR01, RT01, D200
-        """
-        Get addition of ``DataFrame`` and `other`, element-wise (binary operator `add`).
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self._binary_op(
-            "add",
-            other,
-            axis=axis,
-            level=level,
-            fill_value=fill_value,
-        )
-
-    def assign(self, **kwargs):  # noqa: PR01, RT01, D200
-        """
-        Assign new columns to a ``DataFrame``.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-
-        df = self.copy()
-        for k, v in kwargs.items():
-            if callable(v):
-                df[k] = v(df)
-            else:
-                df[k] = v
-        return df
-
-    @dataframe_not_implemented()
-    def boxplot(
-        self,
-        column=None,
-        by=None,
-        ax=None,
-        fontsize=None,
-        rot=0,
-        grid=True,
-        figsize=None,
-        layout=None,
-        return_type=None,
-        backend=None,
-        **kwargs,
-    ):  # noqa: PR01, RT01, D200
-        """
-        Make a box plot from ``DataFrame`` columns.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return to_pandas(self).boxplot(
-            column=column,
-            by=by,
-            ax=ax,
-            fontsize=fontsize,
-            rot=rot,
-            grid=grid,
-            figsize=figsize,
-            layout=layout,
-            return_type=return_type,
-            backend=backend,
-            **kwargs,
-        )
-
-    @dataframe_not_implemented()
-    def combine(
-        self, other, func, fill_value=None, overwrite=True
-    ):  # noqa: PR01, RT01, D200
-        """
-        Perform column-wise combine with another ``DataFrame``.
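A small pandas-level usage sketch of the internal/external `by` split handled above (a column label of the frame is "internal", an unrelated aligned `Series` is "external"):

```python
import pandas as pd

df = pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]})
external = pd.Series(["x", "y", "x"], name="ext")

# One grouper is a column label of df, the other an external Series;
# both are accepted together in the same `by` list.
out = df.groupby(["k", external])["v"].sum()
assert out.loc[("a", "x")] == 1
```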
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().combine(other, func, fill_value=fill_value, overwrite=overwrite) - - def compare( - self, - other, - align_axis=1, - keep_shape: bool = False, - keep_equal: bool = False, - result_names=("self", "other"), - ) -> DataFrame: # noqa: PR01, RT01, D200 - """ - Compare to another ``DataFrame`` and show the differences. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not isinstance(other, DataFrame): - raise TypeError(f"Cannot compare DataFrame to {type(other)}") - other = self._validate_other(other, 0, compare_index=True) - return self.__constructor__( - query_compiler=self._query_compiler.compare( - other, - align_axis=align_axis, - keep_shape=keep_shape, - keep_equal=keep_equal, - result_names=result_names, - ) - ) - - def corr( - self, - method: str | Callable = "pearson", - min_periods: int | None = None, - numeric_only: bool = False, - ): # noqa: PR01, RT01, D200 - """ - Compute pairwise correlation of columns, excluding NA/null values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - corr_df = self - if numeric_only: - corr_df = self.drop( - columns=[ - i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) - ] - ) - return self.__constructor__( - query_compiler=corr_df._query_compiler.corr( - method=method, - min_periods=min_periods, - ) - ) - - @dataframe_not_implemented() - def corrwith( - self, other, axis=0, drop=False, method="pearson", numeric_only=False - ): # noqa: PR01, RT01, D200 - """ - Compute pairwise correlation. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, DataFrame): - other = other._query_compiler.to_pandas() - return self._default_to_pandas( - pandas.DataFrame.corrwith, - other, - axis=axis, - drop=drop, - method=method, - numeric_only=numeric_only, - ) - - @dataframe_not_implemented() - def cov( - self, - min_periods: int | None = None, - ddof: int | None = 1, - numeric_only: bool = False, - ): - """ - Compute pairwise covariance of columns, excluding NA/null values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.cov( - min_periods=min_periods, - ddof=ddof, - numeric_only=numeric_only, - ) - ) - - @dataframe_not_implemented() - def dot(self, other): # noqa: PR01, RT01, D200 - """ - Compute the matrix multiplication between the ``DataFrame`` and `other`. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - if isinstance(other, BasePandasDataset): - common = self.columns.union(other.index) - if len(common) > len(self.columns) or len(common) > len( - other - ): # pragma: no cover - raise ValueError("Matrices are not aligned") - - if isinstance(other, DataFrame): - return self.__constructor__( - query_compiler=self._query_compiler.dot( - other.reindex(index=common), squeeze_self=False - ) - ) - else: - return self._reduce_dimension( - query_compiler=self._query_compiler.dot( - other.reindex(index=common), squeeze_self=False - ) - ) - - other = np.asarray(other) - if self.shape[1] != other.shape[0]: - raise ValueError( - f"Dot product shape mismatch, {self.shape} vs {other.shape}" - ) - - if len(other.shape) > 1: - return self.__constructor__( - query_compiler=self._query_compiler.dot(other, squeeze_self=False) - ) - - return self._reduce_dimension( - query_compiler=self._query_compiler.dot(other, squeeze_self=False) - ) - - def eq(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Perform equality comparison of ``DataFrame`` and `other` (binary operator `eq`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("eq", other, axis=axis, level=level) - - def equals(self, other) -> bool: # noqa: PR01, RT01, D200 - """ - Test whether two objects contain the same elements. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, pandas.DataFrame): - # Copy into a Modin DataFrame to simplify logic below - other = self.__constructor__(other) - - if ( - type(self) is not type(other) - or not self.index.equals(other.index) - or not self.columns.equals(other.columns) - ): - return False - - result = self.__constructor__( - query_compiler=self._query_compiler.equals(other._query_compiler) - ) - return result.all(axis=None) - - def _update_var_dicts_in_kwargs(self, expr, kwargs): - """ - Copy variables with "@" prefix in `local_dict` and `global_dict` keys of kwargs. - - Parameters - ---------- - expr : str - The expression string to search variables with "@" prefix. - kwargs : dict - See the documentation for eval() for complete details on the keyword arguments accepted by query(). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if "@" not in expr: - return - frame = sys._getframe() - try: - f_locals = frame.f_back.f_back.f_back.f_back.f_locals - f_globals = frame.f_back.f_back.f_back.f_back.f_globals - finally: - del frame - local_names = set(re.findall(r"@([\w]+)", expr)) - local_dict = {} - global_dict = {} - - for name in local_names: - for dct_out, dct_in in ((local_dict, f_locals), (global_dict, f_globals)): - try: - dct_out[name] = dct_in[name] - except KeyError: - pass - - if local_dict: - local_dict.update(kwargs.get("local_dict") or {}) - kwargs["local_dict"] = local_dict - if global_dict: - global_dict.update(kwargs.get("global_dict") or {}) - kwargs["global_dict"] = global_dict - - @dataframe_not_implemented() - def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 - """ - Evaluate a string describing operations on ``DataFrame`` columns. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._validate_eval_query(expr, **kwargs) - inplace = validate_bool_kwarg(inplace, "inplace") - self._update_var_dicts_in_kwargs(expr, kwargs) - new_query_compiler = self._query_compiler.eval(expr, **kwargs) - return_type = type( - pandas.DataFrame(columns=self.columns) - .astype(self.dtypes) - .eval(expr, **kwargs) - ).__name__ - if return_type == type(self).__name__: - return self._create_or_update_from_compiler(new_query_compiler, inplace) - else: - if inplace: - raise ValueError("Cannot operate inplace if there is no assignment") - return getattr(sys.modules[self.__module__], return_type)( - query_compiler=new_query_compiler - ) - - def fillna( - self, - value: Hashable | Mapping | Series | DataFrame = None, - *, - method: FillnaOptions | None = None, - axis: Axis | None = None, - inplace: bool = False, - limit: int | None = None, - downcast: dict | None = None, - ) -> DataFrame | None: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().fillna( - self_is_series=False, - value=value, - method=method, - axis=axis, - inplace=inplace, - limit=limit, - downcast=downcast, - ) - - def floordiv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `floordiv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "floordiv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - @classmethod - @dataframe_not_implemented() - def from_dict( - cls, data, orient="columns", dtype=None, columns=None - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Construct ``DataFrame`` from dict of array-like or dicts. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return from_pandas( - pandas.DataFrame.from_dict( - data, orient=orient, dtype=dtype, columns=columns - ) - ) - - @classmethod - @dataframe_not_implemented() - def from_records( - cls, - data, - index=None, - exclude=None, - columns=None, - coerce_float=False, - nrows=None, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Convert structured or record ndarray to ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return from_pandas( - pandas.DataFrame.from_records( - data, - index=index, - exclude=exclude, - columns=columns, - coerce_float=coerce_float, - nrows=nrows, - ) - ) - - def ge(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get greater than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("ge", other, axis=axis, level=level) - - def gt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get greater than comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("gt", other, axis=axis, level=level) - - @dataframe_not_implemented() - def hist( - self, - column=None, - by=None, - grid=True, - xlabelsize=None, - xrot=None, - ylabelsize=None, - yrot=None, - ax=None, - sharex=False, - sharey=False, - figsize=None, - layout=None, - bins=10, - **kwds, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Make a histogram of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.hist, - column=column, - by=by, - grid=grid, - xlabelsize=xlabelsize, - xrot=xrot, - ylabelsize=ylabelsize, - yrot=yrot, - ax=ax, - sharex=sharex, - sharey=sharey, - figsize=figsize, - layout=layout, - bins=bins, - **kwds, - ) - - def info( - self, - verbose: bool | None = None, - buf: IO[str] | None = None, - max_cols: int | None = None, - memory_usage: bool | str | None = None, - show_counts: bool | None = None, - null_counts: bool | None = None, - ): # noqa: PR01, D200 - """ - Print a concise summary of the ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def put_str(src, output_len=None, spaces=2): - src = str(src) - return src.ljust(output_len if output_len else len(src)) + " " * spaces - - def format_size(num): - for x in ["bytes", "KB", "MB", "GB", "TB"]: - if num < 1024.0: - return f"{num:3.1f} {x}" - num /= 1024.0 - return f"{num:3.1f} PB" - - output = [] - - type_line = str(type(self)) - index_line = "SnowflakeIndex" - columns = self.columns - columns_len = len(columns) - dtypes = self.dtypes - dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}" - - if max_cols is None: - max_cols = 100 - - exceeds_info_cols = columns_len > max_cols - - if buf is None: - buf = sys.stdout - - if null_counts is None: - null_counts = not exceeds_info_cols - - if verbose is None: - verbose = not exceeds_info_cols - - if null_counts and verbose: - # We're gonna take items from `non_null_count` in a loop, which - # works kinda slow with `Modin.Series`, that's why we call `_to_pandas()` here - # that will be faster. 
- non_null_count = self.count()._to_pandas() - - if memory_usage is None: - memory_usage = True - - def get_header(spaces=2): - output = [] - head_label = " # " - column_label = "Column" - null_label = "Non-Null Count" - dtype_label = "Dtype" - non_null_label = " non-null" - delimiter = "-" - - lengths = {} - lengths["head"] = max(len(head_label), len(pprint_thing(len(columns)))) - lengths["column"] = max( - len(column_label), max(len(pprint_thing(col)) for col in columns) - ) - lengths["dtype"] = len(dtype_label) - dtype_spaces = ( - max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes)) - - lengths["dtype"] - ) - - header = put_str(head_label, lengths["head"]) + put_str( - column_label, lengths["column"] - ) - if null_counts: - lengths["null"] = max( - len(null_label), - max(len(pprint_thing(x)) for x in non_null_count) - + len(non_null_label), - ) - header += put_str(null_label, lengths["null"]) - header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces) - - output.append(header) - - delimiters = put_str(delimiter * lengths["head"]) + put_str( - delimiter * lengths["column"] - ) - if null_counts: - delimiters += put_str(delimiter * lengths["null"]) - delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces) - output.append(delimiters) - - return output, lengths - - output.extend([type_line, index_line]) - - def verbose_repr(output): - columns_line = f"Data columns (total {len(columns)} columns):" - header, lengths = get_header() - output.extend([columns_line, *header]) - for i, col in enumerate(columns): - i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]]) - - to_append = put_str(f" {i}", lengths["head"]) + put_str( - col_s, lengths["column"] - ) - if null_counts: - non_null = pprint_thing(non_null_count[col]) - to_append += put_str(f"{non_null} non-null", lengths["null"]) - to_append += put_str(dtype, lengths["dtype"], spaces=0) - output.append(to_append) - - def non_verbose_repr(output): - output.append(columns._summary(name="Columns")) - - if verbose: - verbose_repr(output) - else: - non_verbose_repr(output) - - output.append(dtypes_line) - - if memory_usage: - deep = memory_usage == "deep" - mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum() - mem_line = f"memory usage: {format_size(mem_usage_bytes)}" - - output.append(mem_line) - - output.append("") - buf.write("\n".join(output)) - - def insert( - self, - loc: int, - column: Hashable, - value: Scalar | AnyArrayLike, - allow_duplicates: bool | NoDefault = no_default, - ) -> None: - """ - Insert column into ``DataFrame`` at specified location. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - raise_if_native_pandas_objects(value) - if allow_duplicates is no_default: - allow_duplicates = False - if not allow_duplicates and column in self.columns: - raise ValueError(f"cannot insert {column}, already exists") - - if not isinstance(loc, int): - raise TypeError("loc must be int") - - # If columns labels are multilevel, we implement following behavior (this is - # name native pandas): - # Case 1: if 'column' is tuple it's length must be same as number of levels - # otherwise raise error. - # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in - # empty strings to match the length of column levels in self frame. - if self.columns.nlevels > 1: - if isinstance(column, tuple) and len(column) != self.columns.nlevels: - # same error as native pandas. 
- raise ValueError("Item must have length equal to number of levels.") - if not isinstance(column, tuple): - # Fill empty strings to match length of levels - suffix = [""] * (self.columns.nlevels - 1) - column = tuple([column] + suffix) - - # Dictionary keys are treated as index column and this should be joined with - # index of target dataframe. This behavior is similar to 'value' being DataFrame - # or Series, so we simply create Series from dict data here. - if isinstance(value, dict): - value = Series(value, name=column) - - if isinstance(value, DataFrame) or ( - isinstance(value, np.ndarray) and len(value.shape) > 1 - ): - # Supported numpy array shapes are - # 1. (N, ) -> Ex. [1, 2, 3] - # 2. (N, 1) -> Ex> [[1], [2], [3]] - if value.shape[1] != 1: - if isinstance(value, DataFrame): - # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin - raise ValueError( - f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead." - ) - else: - raise ValueError( - f"Expected a 1D array, got an array with shape {value.shape}" - ) - # Change numpy array shape from (N, 1) to (N, ) - if isinstance(value, np.ndarray): - value = value.squeeze(axis=1) - - if ( - is_list_like(value) - and not isinstance(value, (Series, DataFrame)) - and len(value) != self.shape[0] - and not 0 == self.shape[0] # dataframe holds no rows - ): - raise ValueError( - "Length of values ({}) does not match length of index ({})".format( - len(value), len(self) - ) - ) - if not -len(self.columns) <= loc <= len(self.columns): - raise IndexError( - f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}" - ) - elif loc < 0: - raise ValueError("unbounded slice") - - join_on_index = False - if isinstance(value, (Series, DataFrame)): - value = value._query_compiler - join_on_index = True - elif is_list_like(value): - value = Series(value, name=column)._query_compiler - - new_query_compiler = self._query_compiler.insert( - loc, column, value, join_on_index - ) - # In pandas, 'insert' operation is always inplace. - self._update_inplace(new_query_compiler=new_query_compiler) - - @dataframe_not_implemented() - def interpolate( - self, - method="linear", - axis=0, - limit=None, - inplace=False, - limit_direction: str | None = None, - limit_area=None, - downcast=None, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Fill NaN values using an interpolation method. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.interpolate, - method=method, - axis=axis, - limit=limit, - inplace=inplace, - limit_direction=limit_direction, - limit_area=limit_area, - downcast=downcast, - **kwargs, - ) - - def iterrows(self) -> Iterator[tuple[Hashable, Series]]: - """ - Iterate over ``DataFrame`` rows as (index, ``Series``) pairs. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def iterrow_builder(s): - """Return tuple of the given `s` parameter name and the parameter themselves.""" - return s.name, s - - # Raise warning message since iterrows is very inefficient. - WarningMessage.single_warning( - DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows") - ) - - partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder) - yield from partition_iterator - - def items(self): # noqa: D200 - """ - Iterate over (column name, ``Series``) pairs. 
- """ - - def items_builder(s): - """Return tuple of the given `s` parameter name and the parameter themselves.""" - return s.name, s - - partition_iterator = PartitionIterator(self, 1, items_builder) - yield from partition_iterator - - def itertuples( - self, index: bool = True, name: str | None = "Pandas" - ) -> Iterable[tuple[Any, ...]]: - """ - Iterate over ``DataFrame`` rows as ``namedtuple``-s. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - def itertuples_builder(s): - """Return the next namedtuple.""" - # s is the Series of values in the current row. - fields = [] # column names - data = [] # values under each column - - if index: - data.append(s.name) - fields.append("Index") - - # Fill column names and values. - fields.extend(list(self.columns)) - data.extend(s) - - if name is not None: - # Creating the namedtuple. - itertuple = collections.namedtuple(name, fields, rename=True) - return itertuple._make(data) - - # When the name is None, return a regular tuple. - return tuple(data) - - # Raise warning message since itertuples is very inefficient. - WarningMessage.single_warning( - DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples") - ) - return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True) - - def join( - self, - other: DataFrame | Series | Iterable[DataFrame | Series], - on: IndexLabel | None = None, - how: str = "left", - lsuffix: str = "", - rsuffix: str = "", - sort: bool = False, - validate: str | None = None, - ) -> DataFrame: - """ - Join columns of another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - for o in other if isinstance(other, list) else [other]: - raise_if_native_pandas_objects(o) - - # Similar to native pandas we implement 'join' using 'pd.merge' method. - # Following code is copied from native pandas (with few changes explained below) - # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002 - if isinstance(other, Series): - # Same error as native pandas. - if other.name is None: - raise ValueError("Other Series must have a name") - other = DataFrame(other) - elif is_list_like(other): - if any([isinstance(o, Series) and o.name is None for o in other]): - raise ValueError("Other Series must have a name") - - if isinstance(other, DataFrame): - if how == "cross": - return pd.merge( - self, - other, - how=how, - on=on, - suffixes=(lsuffix, rsuffix), - sort=sort, - validate=validate, - ) - return pd.merge( - self, - other, - left_on=on, - how=how, - left_index=on is None, - right_index=True, - suffixes=(lsuffix, rsuffix), - sort=sort, - validate=validate, - ) - else: # List of DataFrame/Series - # Same error as native pandas. - if on is not None: - raise ValueError( - "Joining multiple DataFrames only supported for joining on index" - ) - - # Same error as native pandas. - if rsuffix or lsuffix: - raise ValueError( - "Suffixes not supported when joining multiple DataFrames" - ) - - # NOTE: These are not the differences between Snowpark pandas API and pandas behavior - # these are differences between native pandas join behavior when join - # frames have unique index or not. - - # In native pandas logic to join multiple DataFrames/Series is data - # dependent. Under the hood it will either use 'concat' or 'merge' API - # Case 1. If all objects being joined have unique index use 'concat' (axis=1) - # Case 2. Otherwise use 'merge' API by looping through objects left to right. 
- # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046 - - # Even though concat (axis=1) and merge are very similar APIs they have - # some differences which leads to inconsistent behavior in native pandas. - # 1. Treatment of un-named Series - # Case #1: Un-named series is allowed in concat API. Objects are joined - # successfully by assigning a number as columns name (see 'concat' API - # documentation for details on treatment of un-named series). - # Case #2: It raises 'ValueError: Other Series must have a name' - - # 2. how='right' - # Case #1: 'concat' API doesn't support right join. It raises - # 'ValueError: Only can inner (intersect) or outer (union) join the other axis' - # Case #2: Merges successfully. - - # 3. Joining frames with duplicate labels but no conflict with other frames - # Example: self = DataFrame(... columns=["A", "B"]) - # other = [DataFrame(... columns=["C", "C"])] - # Case #1: 'ValueError: Indexes have overlapping values' - # Case #2: Merged successfully. - - # In addition to this, native pandas implementation also leads to another - # type of inconsistency where left.join(other, ...) and - # left.join([other], ...) might behave differently for cases mentioned - # above. - # Example: - # import pandas as pd - # df = pd.DataFrame({"a": [4, 5]}) - # other = pd.Series([1, 2]) - # df.join([other]) # this is successful - # df.join(other) # this raises 'ValueError: Other Series must have a name' - - # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API - # to join multiple DataFrame/Series. So always follow the behavior - # documented as Case #2 above. - - joined = self - for frame in other: - if isinstance(frame, DataFrame): - overlapping_cols = set(joined.columns).intersection( - set(frame.columns) - ) - if len(overlapping_cols) > 0: - # Native pandas raises: 'Indexes have overlapping values' - # We differ slightly from native pandas message to make it more - # useful to users. - raise ValueError( - f"Join dataframes have overlapping column labels: {overlapping_cols}" - ) - joined = pd.merge( - joined, - frame, - how=how, - left_index=True, - right_index=True, - validate=validate, - sort=sort, - suffixes=(None, None), - ) - return joined - - def isna(self): - return super().isna() - - def isnull(self): - return super().isnull() - - @dataframe_not_implemented() - def isetitem(self, loc, value): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.isetitem, - loc=loc, - value=value, - ) - - def le(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get less than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("le", other, axis=axis, level=level) - - def lt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 - """ - Get less than comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("lt", other, axis=axis, level=level) - - def melt( - self, - id_vars=None, - value_vars=None, - var_name=None, - value_name="value", - col_level=None, - ignore_index=True, - ): # noqa: PR01, RT01, D200 - """ - Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set. 
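The loop above reduces `join` to repeated index-on-index merges. The equivalence it relies on, shown in plain pandas:

```python
import pandas as pd

left = pd.DataFrame({"a": [1, 2]})
right = pd.DataFrame({"b": [3, 4]})

# An index-to-index left merge is exactly what join(other) does by
# default, which is why looping with pd.merge gives consistent
# multi-frame joins.
via_join = left.join(right)
via_merge = pd.merge(
    left, right, how="left", left_index=True, right_index=True
)
assert via_join.equals(via_merge)
```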
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if id_vars is None: - id_vars = [] - if not is_list_like(id_vars): - id_vars = [id_vars] - if value_vars is None: - # Behavior of Index.difference changed in 2.2.x - # https://github.com/pandas-dev/pandas/pull/55113 - # This change needs upstream to Modin: - # https://github.com/modin-project/modin/issues/7206 - value_vars = self.columns.drop(id_vars) - if var_name is None: - columns_name = self._query_compiler.get_index_name(axis=1) - var_name = columns_name if columns_name is not None else "variable" - return self.__constructor__( - query_compiler=self._query_compiler.melt( - id_vars=id_vars, - value_vars=value_vars, - var_name=var_name, - value_name=value_name, - col_level=col_level, - ignore_index=ignore_index, - ) - ) - - @dataframe_not_implemented() - def memory_usage(self, index=True, deep=False): # noqa: PR01, RT01, D200 - """ - Return the memory usage of each column in bytes. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - if index: - result = self._reduce_dimension( - self._query_compiler.memory_usage(index=False, deep=deep) - ) - index_value = self.index.memory_usage(deep=deep) - return pd.concat( - [Series(index_value, index=["Index"]), result] - ) # pragma: no cover - return super().memory_usage(index=index, deep=deep) - - def merge( - self, - right: DataFrame | Series, - how: str = "inner", - on: IndexLabel | None = None, - left_on: Hashable - | AnyArrayLike - | Sequence[Hashable | AnyArrayLike] - | None = None, - right_on: Hashable - | AnyArrayLike - | Sequence[Hashable | AnyArrayLike] - | None = None, - left_index: bool = False, - right_index: bool = False, - sort: bool = False, - suffixes: Suffixes = ("_x", "_y"), - copy: bool = True, - indicator: bool = False, - validate: str | None = None, - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # Raise error if native pandas objects are passed. - raise_if_native_pandas_objects(right) - - if isinstance(right, Series) and right.name is None: - raise ValueError("Cannot merge a Series without a name") - if not isinstance(right, (Series, DataFrame)): - raise TypeError( - f"Can only merge Series or DataFrame objects, a {type(right)} was passed" - ) - - if isinstance(right, Series): - right_column_nlevels = ( - len(right.name) if isinstance(right.name, tuple) else 1 - ) - else: - right_column_nlevels = right.columns.nlevels - if self.columns.nlevels != right_column_nlevels: - # This is deprecated in native pandas. We raise explicit error for this. - raise ValueError( - "Can not merge objects with different column levels." - + f" ({self.columns.nlevels} levels on the left," - + f" {right_column_nlevels} on the right)" - ) - - # Merge empty native pandas dataframes for error checking. Otherwise, it will - # require a lot of logic to be written. This takes care of raising errors for - # following scenarios: - # 1. Only 'left_index' is set to True. - # 2. Only 'right_index is set to True. - # 3. Only 'left_on' is provided. - # 4. Only 'right_on' is provided. - # 5. 'on' and 'left_on' both are provided - # 6. 'on' and 'right_on' both are provided - # 7. 'on' and 'left_index' both are provided - # 8. 'on' and 'right_index' both are provided - # 9. 'left_on' and 'left_index' both are provided - # 10. 'right_on' and 'right_index' both are provided - # 11. Length mismatch between 'left_on' and 'right_on' - # 12. 'left_index' is not a bool - # 13. 'right_index' is not a bool - # 14. 
'on' is not None and how='cross'
-        # 15. 'left_on' is not None and how='cross'
-        # 16. 'right_on' is not None and how='cross'
-        # 17. 'left_index' is True and how='cross'
-        # 18. 'right_index' is True and how='cross'
-        # 19. Unknown label in 'on', 'left_on' or 'right_on'
-        # 20. Provided 'suffixes' is not sufficient to resolve conflicts.
-        # 21. Merging on column with duplicate labels.
-        # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'}
-        # 23. conflict with existing labels for array-like join key
-        # 24. 'indicator' argument is not bool or str
-        # 25. indicator column label conflicts with existing data labels
-        create_empty_native_pandas_frame(self).merge(
-            create_empty_native_pandas_frame(right),
-            on=on,
-            how=how,
-            left_on=replace_external_data_keys_with_empty_pandas_series(left_on),
-            right_on=replace_external_data_keys_with_empty_pandas_series(right_on),
-            left_index=left_index,
-            right_index=right_index,
-            suffixes=suffixes,
-            indicator=indicator,
-        )
-
-        return self.__constructor__(
-            query_compiler=self._query_compiler.merge(
-                right._query_compiler,
-                how=how,
-                on=on,
-                left_on=replace_external_data_keys_with_query_compiler(self, left_on),
-                right_on=replace_external_data_keys_with_query_compiler(
-                    right, right_on
-                ),
-                left_index=left_index,
-                right_index=right_index,
-                sort=sort,
-                suffixes=suffixes,
-                copy=copy,
-                indicator=indicator,
-                validate=validate,
-            )
-        )
-
-    def mod(
-        self, other, axis="columns", level=None, fill_value=None
-    ):  # noqa: PR01, RT01, D200
-        """
-        Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `mod`).
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self._binary_op(
-            "mod",
-            other,
-            axis=axis,
-            level=level,
-            fill_value=fill_value,
-        )
-
-    def mul(
-        self, other, axis="columns", level=None, fill_value=None
-    ):  # noqa: PR01, RT01, D200
-        """
-        Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`).
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self._binary_op(
-            "mul",
-            other,
-            axis=axis,
-            level=level,
-            fill_value=fill_value,
-        )
-
-    multiply = mul
-
-    def rmul(
-        self, other, axis="columns", level=None, fill_value=None
-    ):  # noqa: PR01, RT01, D200
-        """
-        Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `rmul`).
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self._binary_op(
-            "rmul",
-            other,
-            axis=axis,
-            level=level,
-            fill_value=fill_value,
-        )
-
-    def ne(self, other, axis="columns", level=None):  # noqa: PR01, RT01, D200
-        """
-        Get not equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ne`).
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self._binary_op("ne", other, axis=axis, level=level)
-
-    def nlargest(self, n, columns, keep="first"):  # noqa: PR01, RT01, D200
-        """
-        Return the first `n` rows ordered by `columns` in descending order.
-        """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        return self.__constructor__(
-            query_compiler=self._query_compiler.nlargest(n, columns, keep)
-        )
-
-    def nsmallest(self, n, columns, keep="first"):  # noqa: PR01, RT01, D200
-        """
-        Return the first `n` rows ordered by `columns` in ascending order.
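The "merge empty frames for error checking" trick above can be sketched standalone; `validate_merge_args` is a hypothetical name used only for this illustration:

```python
import pandas as pd

def validate_merge_args(left, right, **kwargs):
    # Merge zero-row slices that keep the schema: native pandas raises
    # the same argument errors it would for real data, at no data cost.
    left.iloc[:0].merge(right.iloc[:0], **kwargs)

left = pd.DataFrame({"k": [1], "v": [2]})
right = pd.DataFrame({"k": [1], "w": [3]})
validate_merge_args(left, right, on="k", how="inner")  # ok, no error
```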
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.nsmallest( - n=n, columns=columns, keep=keep - ) - ) - - def unstack( - self, - level: int | str | list = -1, - fill_value: int | str | dict = None, - sort: bool = True, - ): - """ - Pivot a level of the (necessarily hierarchical) index labels. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # This ensures that non-pandas MultiIndex objects are caught. - nlevels = self._query_compiler.nlevels() - is_multiindex = nlevels > 1 - - if not is_multiindex or ( - is_multiindex and is_list_like(level) and len(level) == nlevels - ): - return self._reduce_dimension( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=False - ) - ) - else: - return self.__constructor__( - query_compiler=self._query_compiler.unstack( - level, fill_value, sort, is_series_input=False - ) - ) - - def pivot( - self, - *, - columns: Any, - index: Any | NoDefault = no_default, - values: Any | NoDefault = no_default, - ): - """ - Return reshaped DataFrame organized by given index / column values. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if index is no_default: - index = None # pragma: no cover - if values is no_default: - values = None - - # if values is not specified, it should be the remaining columns not in - # index or columns - if values is None: - values = list(self.columns) - if index is not None: - values = [v for v in values if v not in index] - if columns is not None: - values = [v for v in values if v not in columns] - - return self.__constructor__( - query_compiler=self._query_compiler.pivot( - index=index, columns=columns, values=values - ) - ) - - def pivot_table( - self, - values=None, - index=None, - columns=None, - aggfunc="mean", - fill_value=None, - margins=False, - dropna=True, - margins_name="All", - observed=False, - sort=True, - ): - """ - Create a spreadsheet-style pivot table as a ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - result = self.__constructor__( - query_compiler=self._query_compiler.pivot_table( - index=index, - values=values, - columns=columns, - aggfunc=aggfunc, - fill_value=fill_value, - margins=margins, - dropna=dropna, - margins_name=margins_name, - observed=observed, - sort=sort, - ) - ) - return result - - @dataframe_not_implemented() - @property - def plot( - self, - x=None, - y=None, - kind="line", - ax=None, - subplots=False, - sharex=None, - sharey=False, - layout=None, - figsize=None, - use_index=True, - title=None, - grid=None, - legend=True, - style=None, - logx=False, - logy=False, - loglog=False, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - rot=None, - fontsize=None, - colormap=None, - table=False, - yerr=None, - xerr=None, - secondary_y=False, - sort_columns=False, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Make plots of ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._to_pandas().plot - - def pow( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "pow", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - @dataframe_not_implemented() - def prod( - self, - axis=None, - skipna=True, - numeric_only=False, - min_count=0, - **kwargs, - ): # noqa: PR01, RT01, D200 - """ - Return the product of the values over the requested axis. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - validate_bool_kwarg(skipna, "skipna", none_allowed=False) - axis = self._get_axis_number(axis) - axis_to_apply = self.columns if axis else self.index - if ( - skipna is not False - and numeric_only is None - and min_count > len(axis_to_apply) - ): - new_index = self.columns if not axis else self.index - return Series( - [np.nan] * len(new_index), index=new_index, dtype=np.dtype("object") - ) - - data = self._validate_dtypes_sum_prod_mean(axis, numeric_only, ignore_axis=True) - if min_count > 1: - return data._reduce_dimension( - data._query_compiler.prod_min_count( - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - min_count=min_count, - **kwargs, - ) - ) - return data._reduce_dimension( - data._query_compiler.prod( - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - min_count=min_count, - **kwargs, - ) - ) - - product = prod - - def quantile( - self, - q: Scalar | ListLike = 0.5, - axis: Axis = 0, - numeric_only: bool = False, - interpolation: Literal[ - "linear", "lower", "higher", "midpoint", "nearest" - ] = "linear", - method: Literal["single", "table"] = "single", - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().quantile( - q=q, - axis=axis, - numeric_only=numeric_only, - interpolation=interpolation, - method=method, - ) - - @dataframe_not_implemented() - def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 - """ - Query the columns of a ``DataFrame`` with a boolean expression. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._update_var_dicts_in_kwargs(expr, kwargs) - self._validate_eval_query(expr, **kwargs) - inplace = validate_bool_kwarg(inplace, "inplace") - new_query_compiler = self._query_compiler.query(expr, **kwargs) - return self._create_or_update_from_compiler(new_query_compiler, inplace) - - def rename( - self, - mapper: Renamer | None = None, - *, - index: Renamer | None = None, - columns: Renamer | None = None, - axis: Axis | None = None, - copy: bool | None = None, - inplace: bool = False, - level: Level | None = None, - errors: IgnoreRaise = "ignore", - ) -> DataFrame | None: - """ - Alter axes labels. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - if mapper is None and index is None and columns is None: - raise TypeError("must pass an index to rename") - - if index is not None or columns is not None: - if axis is not None: - raise TypeError( - "Cannot specify both 'axis' and any of 'index' or 'columns'" - ) - elif mapper is not None: - raise TypeError( - "Cannot specify both 'mapper' and any of 'index' or 'columns'" - ) - else: - # use the mapper argument - if axis and self._get_axis_number(axis) == 1: - columns = mapper - else: - index = mapper - - if copy is not None: - WarningMessage.ignored_argument( - operation="dataframe.rename", - argument="copy", - message="copy parameter has been ignored with Snowflake execution engine", - ) - - if isinstance(index, dict): - index = Series(index) - - new_qc = self._query_compiler.rename( - index_renamer=index, columns_renamer=columns, level=level, errors=errors - ) - return self._create_or_update_from_compiler( - new_query_compiler=new_qc, inplace=inplace - ) - - def reindex( - self, - labels=None, - index=None, - columns=None, - axis=None, - method=None, - copy=None, - level=None, - fill_value=np.nan, - limit=None, - tolerance=None, - ): # noqa: PR01, RT01, D200 - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - - axis = self._get_axis_number(axis) - if axis == 0 and labels is not None: - index = labels - elif labels is not None: - columns = labels - return super().reindex( - index=index, - columns=columns, - method=method, - copy=copy, - level=level, - fill_value=fill_value, - limit=limit, - tolerance=tolerance, - ) - - @dataframe_not_implemented() - def reindex_like( - self, - other, - method=None, - copy: bool | None = None, - limit=None, - tolerance=None, - ) -> DataFrame: # pragma: no cover - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if copy is None: - copy = True - # docs say "Same as calling .reindex(index=other.index, columns=other.columns,...).": - # https://pandas.pydata.org/pandas-docs/version/1.4/reference/api/pandas.DataFrame.reindex_like.html - return self.reindex( - index=other.index, - columns=other.columns, - method=method, - copy=copy, - limit=limit, - tolerance=tolerance, - ) - - def replace( - self, - to_replace=None, - value=no_default, - inplace: bool = False, - limit=None, - regex: bool = False, - method: str | NoDefault = no_default, - ): - """ - Replace values given in `to_replace` with `value`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - new_query_compiler = self._query_compiler.replace( - to_replace=to_replace, - value=value, - limit=limit, - regex=regex, - method=method, - ) - return self._create_or_update_from_compiler(new_query_compiler, inplace) - - def rfloordiv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `rfloordiv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rfloordiv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def radd( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get addition of ``DataFrame`` and `other`, element-wise (binary operator `radd`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "radd", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rmod( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `rmod`). - """ - return self._binary_op( - "rmod", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200 - return super().round(decimals, args=args, **kwargs) - - def rpow( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rpow", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rsub( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `rsub`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rsub", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - def rtruediv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `rtruediv`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "rtruediv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - rdiv = rtruediv - - def select_dtypes( - self, - include: ListLike | str | type | None = None, - exclude: ListLike | str | type | None = None, - ) -> DataFrame: - """ - Return a subset of the ``DataFrame``'s columns based on the column dtypes. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # This line defers argument validation to pandas, which will raise errors on our behalf in cases - # like if `include` and `exclude` are None, the same type is specified in both lists, or a string - # dtype (as opposed to object) is specified. - pandas.DataFrame().select_dtypes(include, exclude) - - if include and not is_list_like(include): - include = [include] - elif include is None: - include = [] - if exclude and not is_list_like(exclude): - exclude = [exclude] - elif exclude is None: - exclude = [] - - sel = tuple(map(set, (include, exclude))) - - # The width of the np.int_/float_ alias differs between Windows and other platforms, so - # we need to include a workaround. 
- # https://github.com/numpy/numpy/issues/9464 - # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036 - def check_sized_number_infer_dtypes(dtype): - if (isinstance(dtype, str) and dtype == "int") or (dtype is int): - return [np.int32, np.int64] - elif dtype == "float" or dtype is float: - return [np.float64, np.float32] - else: - return [infer_dtype_from_object(dtype)] - - include, exclude = map( - lambda x: set( - itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x)) - ), - sel, - ) - # We need to index on column position rather than label in case of duplicates - include_these = pandas.Series(not bool(include), index=range(len(self.columns))) - exclude_these = pandas.Series(not bool(exclude), index=range(len(self.columns))) - - def is_dtype_instance_mapper(dtype): - return functools.partial(issubclass, dtype.type) - - for i, dtype in enumerate(self.dtypes): - if include: - include_these[i] = any(map(is_dtype_instance_mapper(dtype), include)) - if exclude: - exclude_these[i] = not any( - map(is_dtype_instance_mapper(dtype), exclude) - ) - - dtype_indexer = include_these & exclude_these - indicate = [i for i, should_keep in dtype_indexer.items() if should_keep] - # We need to use iloc instead of drop in case of duplicate column names - return self.iloc[:, indicate] - - def shift( - self, - periods: int | Sequence[int] = 1, - freq=None, - axis: Axis = 0, - fill_value: Hashable = no_default, - suffix: str | None = None, - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().shift(periods, freq, axis, fill_value, suffix) - - def set_index( - self, - keys: IndexLabel - | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable], - drop: bool = True, - append: bool = False, - inplace: bool = False, - verify_integrity: bool = False, - ) -> None | DataFrame: - """ - Set the ``DataFrame`` index using existing columns. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - inplace = validate_bool_kwarg(inplace, "inplace") - if not isinstance(keys, list): - keys = [keys] - - # make sure key is either hashable, index, or series - label_or_series = [] - - missing = [] - columns = self.columns.tolist() - for key in keys: - raise_if_native_pandas_objects(key) - if isinstance(key, pd.Series): - label_or_series.append(key._query_compiler) - elif isinstance(key, (np.ndarray, list, Iterator)): - label_or_series.append(pd.Series(key)._query_compiler) - elif isinstance(key, (pd.Index, pandas.MultiIndex)): - label_or_series += [ - s._query_compiler for s in self._to_series_list(key) - ] - else: - if not is_hashable(key): - raise TypeError( - f'The parameter "keys" may be a column key, one-dimensional array, or a list ' - f"containing only valid column keys and one-dimensional arrays. 
Received column " - f"of type {type(key)}" - ) - label_or_series.append(key) - found = key in columns - if columns.count(key) > 1: - raise ValueError(f"The column label '{key}' is not unique") - elif not found: - missing.append(key) - - if missing: - raise KeyError(f"None of {missing} are in the columns") - - new_query_compiler = self._query_compiler.set_index( - label_or_series, drop=drop, append=append - ) - - # TODO: SNOW-782633 improve this code once duplicate is supported - # this needs to pull all index which is inefficient - if verify_integrity and not new_query_compiler.index.is_unique: - duplicates = new_query_compiler.index[ - new_query_compiler.index.to_pandas().duplicated() - ].unique() - raise ValueError(f"Index has duplicate keys: {duplicates}") - - return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace) - - sparse = CachedAccessor("sparse", SparseFrameAccessor) - - def squeeze(self, axis: Axis | None = None): - """ - Squeeze 1 dimensional axis objects into scalars. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) if axis is not None else None - len_columns = self._query_compiler.get_axis_len(1) - if axis == 1 and len_columns == 1: - return Series(query_compiler=self._query_compiler) - if axis in [0, None]: - # get_axis_len(0) results in a sql query to count number of rows in current - # dataframe. We should only compute len_index if axis is 0 or None. - len_index = len(self) - if axis is None and (len_columns == 1 or len_index == 1): - return Series(query_compiler=self._query_compiler).squeeze() - if axis == 0 and len_index == 1: - return Series(query_compiler=self.T._query_compiler) - return self.copy() - - def stack( - self, - level: int | str | list = -1, - dropna: bool | NoDefault = no_default, - sort: bool | NoDefault = no_default, - future_stack: bool = False, # ignored - ): - """ - Stack the prescribed level(s) from columns to index. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if future_stack is not False: - WarningMessage.ignored_argument( # pragma: no cover - operation="DataFrame.stack", - argument="future_stack", - message="future_stack parameter has been ignored with Snowflake execution engine", - ) - if dropna is NoDefault: - dropna = True # pragma: no cover - if sort is NoDefault: - sort = True # pragma: no cover - - # This ensures that non-pandas MultiIndex objects are caught. - is_multiindex = len(self.columns.names) > 1 - if not is_multiindex or ( - is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels - ): - return self._reduce_dimension( - query_compiler=self._query_compiler.stack(level, dropna, sort) - ) - else: - return self.__constructor__( - query_compiler=self._query_compiler.stack(level, dropna, sort) - ) - - def sub( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `sub`). - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "sub", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - subtract = sub - - @dataframe_not_implemented() - def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Write a ``DataFrame`` to the binary Feather format. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas(pandas.DataFrame.to_feather, path, **kwargs) - - @dataframe_not_implemented() - def to_gbq( - self, - destination_table, - project_id=None, - chunksize=None, - reauth=False, - if_exists="fail", - auth_local_webserver=True, - table_schema=None, - location=None, - progress_bar=True, - credentials=None, - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Write a ``DataFrame`` to a Google BigQuery table. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functionsf - return self._default_to_pandas( - pandas.DataFrame.to_gbq, - destination_table, - project_id=project_id, - chunksize=chunksize, - reauth=reauth, - if_exists=if_exists, - auth_local_webserver=auth_local_webserver, - table_schema=table_schema, - location=location, - progress_bar=progress_bar, - credentials=credentials, - ) - - @dataframe_not_implemented() - def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_orc, - path=path, - engine=engine, - index=index, - engine_kwargs=engine_kwargs, - ) - - @dataframe_not_implemented() - def to_html( - self, - buf=None, - columns=None, - col_space=None, - header=True, - index=True, - na_rep="NaN", - formatters=None, - float_format=None, - sparsify=None, - index_names=True, - justify=None, - max_rows=None, - max_cols=None, - show_dimensions=False, - decimal=".", - bold_rows=True, - classes=None, - escape=True, - notebook=False, - border=None, - table_id=None, - render_links=False, - encoding=None, - ): # noqa: PR01, RT01, D200 - """ - Render a ``DataFrame`` as an HTML table. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_html, - buf=buf, - columns=columns, - col_space=col_space, - header=header, - index=index, - na_rep=na_rep, - formatters=formatters, - float_format=float_format, - sparsify=sparsify, - index_names=index_names, - justify=justify, - max_rows=max_rows, - max_cols=max_cols, - show_dimensions=show_dimensions, - decimal=decimal, - bold_rows=bold_rows, - classes=classes, - escape=escape, - notebook=notebook, - border=border, - table_id=table_id, - render_links=render_links, - encoding=None, - ) - - @dataframe_not_implemented() - def to_parquet( - self, - path=None, - engine="auto", - compression="snappy", - index=None, - partition_cols=None, - storage_options: StorageOptions = None, - **kwargs, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - from snowflake.snowpark.modin.pandas.dispatching.factories.dispatcher import ( - FactoryDispatcher, - ) - - return FactoryDispatcher.to_parquet( - self._query_compiler, - path=path, - engine=engine, - compression=compression, - index=index, - partition_cols=partition_cols, - storage_options=storage_options, - **kwargs, - ) - - @dataframe_not_implemented() - def to_period( - self, freq=None, axis=0, copy=True - ): # pragma: no cover # noqa: PR01, RT01, D200 - """ - Convert ``DataFrame`` from ``DatetimeIndex`` to ``PeriodIndex``. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().to_period(freq=freq, axis=axis, copy=copy) - - @dataframe_not_implemented() - def to_records( - self, index=True, column_dtypes=None, index_dtypes=None - ): # noqa: PR01, RT01, D200 - """ - Convert ``DataFrame`` to a NumPy record array. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_records, - index=index, - column_dtypes=column_dtypes, - index_dtypes=index_dtypes, - ) - - @dataframe_not_implemented() - def to_stata( - self, - path: FilePath | WriteBuffer[bytes], - convert_dates: dict[Hashable, str] | None = None, - write_index: bool = True, - byteorder: str | None = None, - time_stamp: datetime.datetime | None = None, - data_label: str | None = None, - variable_labels: dict[Hashable, str] | None = None, - version: int | None = 114, - convert_strl: Sequence[Hashable] | None = None, - compression: CompressionOptions = "infer", - storage_options: StorageOptions = None, - *, - value_labels: dict[Hashable, dict[float | int, str]] | None = None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.to_stata, - path, - convert_dates=convert_dates, - write_index=write_index, - byteorder=byteorder, - time_stamp=time_stamp, - data_label=data_label, - variable_labels=variable_labels, - version=version, - convert_strl=convert_strl, - compression=compression, - storage_options=storage_options, - value_labels=value_labels, - ) - - @dataframe_not_implemented() - def to_xml( - self, - path_or_buffer=None, - index=True, - root_name="data", - row_name="row", - na_rep=None, - attr_cols=None, - elem_cols=None, - namespaces=None, - prefix=None, - encoding="utf-8", - xml_declaration=True, - pretty_print=True, - parser="lxml", - stylesheet=None, - compression="infer", - storage_options=None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.__constructor__( - query_compiler=self._query_compiler.default_to_pandas( - pandas.DataFrame.to_xml, - path_or_buffer=path_or_buffer, - index=index, - root_name=root_name, - row_name=row_name, - na_rep=na_rep, - attr_cols=attr_cols, - elem_cols=elem_cols, - namespaces=namespaces, - prefix=prefix, - encoding=encoding, - xml_declaration=xml_declaration, - pretty_print=pretty_print, - parser=parser, - stylesheet=stylesheet, - compression=compression, - storage_options=storage_options, - ) - ) - - def to_dict( - self, - orient: Literal[ - "dict", "list", "series", "split", "tight", "records", "index" - ] = "dict", - into: type[dict] = dict, - ) -> dict | list[dict]: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._to_pandas().to_dict(orient=orient, into=into) - - def to_timestamp( - self, freq=None, how="start", axis=0, copy=True - ): # noqa: PR01, RT01, D200 - """ - Cast to DatetimeIndex of timestamps, at *beginning* of period. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().to_timestamp(freq=freq, how=how, axis=axis, copy=copy) - - def truediv( - self, other, axis="columns", level=None, fill_value=None - ): # noqa: PR01, RT01, D200 - """ - Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `truediv`). 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op( - "truediv", - other, - axis=axis, - level=level, - fill_value=fill_value, - ) - - div = divide = truediv - - def update( - self, other, join="left", overwrite=True, filter_func=None, errors="ignore" - ): # noqa: PR01, RT01, D200 - """ - Modify in place using non-NA values from another ``DataFrame``. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not isinstance(other, DataFrame): - other = self.__constructor__(other) - query_compiler = self._query_compiler.df_update( - other._query_compiler, - join=join, - overwrite=overwrite, - filter_func=filter_func, - errors=errors, - ) - self._update_inplace(new_query_compiler=query_compiler) - - def diff( - self, - periods: int = 1, - axis: Axis = 0, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().diff( - periods=periods, - axis=axis, - ) - - def drop( - self, - labels: IndexLabel = None, - axis: Axis = 0, - index: IndexLabel = None, - columns: IndexLabel = None, - level: Level = None, - inplace: bool = False, - errors: IgnoreRaise = "raise", - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().drop( - labels=labels, - axis=axis, - index=index, - columns=columns, - level=level, - inplace=inplace, - errors=errors, - ) - - def value_counts( - self, - subset: Sequence[Hashable] | None = None, - normalize: bool = False, - sort: bool = True, - ascending: bool = False, - dropna: bool = True, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return Series( - query_compiler=self._query_compiler.value_counts( - subset=subset, - normalize=normalize, - sort=sort, - ascending=ascending, - dropna=dropna, - ), - name="proportion" if normalize else "count", - ) - - def mask( - self, - cond: DataFrame | Series | Callable | AnyArrayLike, - other: DataFrame | Series | Callable | Scalar | None = np.nan, - *, - inplace: bool = False, - axis: Axis | None = None, - level: Level | None = None, - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, Series) and axis is None: - raise ValueError( - "df.mask requires an axis parameter (0 or 1) when given a Series" - ) - - return super().mask( - cond, - other=other, - inplace=inplace, - axis=axis, - level=level, - ) - - def where( - self, - cond: DataFrame | Series | Callable | AnyArrayLike, - other: DataFrame | Series | Callable | Scalar | None = np.nan, - *, - inplace: bool = False, - axis: Axis | None = None, - level: Level | None = None, - ): - """ - Replace values where the condition is False. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(other, Series) and axis is None: - raise ValueError( - "df.where requires an axis parameter (0 or 1) when given a Series" - ) - - return super().where( - cond, - other=other, - inplace=inplace, - axis=axis, - level=level, - ) - - @dataframe_not_implemented() - def xs(self, key, axis=0, level=None, drop_level=True): # noqa: PR01, RT01, D200 - """ - Return cross-section from the ``DataFrame``. 
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._default_to_pandas( - pandas.DataFrame.xs, key, axis=axis, level=level, drop_level=drop_level - ) - - def set_axis( - self, - labels: IndexLabel, - *, - axis: Axis = 0, - copy: bool | NoDefault = no_default, # ignored - ): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if not is_scalar(axis): - raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") - return super().set_axis( - labels=labels, - # 'columns', 'rows, 'index, 0, and 1 are the only valid axis values for df. - axis=pandas.DataFrame._get_axis_name(axis), - copy=copy, - ) - - def __getattr__(self, key): - """ - Return item identified by `key`. - - Parameters - ---------- - key : hashable - Key to get. - - Returns - ------- - Any - - Notes - ----- - First try to use `__getattribute__` method. If it fails - try to get `key` from ``DataFrame`` fields. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - try: - return object.__getattribute__(self, key) - except AttributeError as err: - if key not in _ATTRS_NO_LOOKUP and key in self.columns: - return self[key] - raise err - - def __setattr__(self, key, value): - """ - Set attribute `value` identified by `key`. - - Parameters - ---------- - key : hashable - Key to set. - value : Any - Value to set. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # While we let users assign to a column labeled "x" with "df.x" , there - # are some attributes that we should assume are NOT column names and - # therefore should follow the default Python object assignment - # behavior. These are: - # - anything in self.__dict__. This includes any attributes that the - # user has added to the dataframe with, e.g., `df.c = 3`, and - # any attribute that Modin has added to the frame, e.g. - # `_query_compiler` and `_siblings` - # - `_query_compiler`, which Modin initializes before it appears in - # __dict__ - # - `_siblings`, which Modin initializes before it appears in __dict__ - # - `_cache`, which pandas.cache_readonly uses to cache properties - # before it appears in __dict__. - if key in ("_query_compiler", "_siblings", "_cache") or key in self.__dict__: - pass - elif key in self and key not in dir(self): - self.__setitem__(key, value) - # Note: return immediately so we don't keep this `key` as dataframe state. - # `__getattr__` will return the columns not present in `dir(self)`, so we do not need - # to manually track this state in the `dir`. - return - elif is_list_like(value) and key not in ["index", "columns"]: - WarningMessage.single_warning( - SET_DATAFRAME_ATTRIBUTE_WARNING - ) # pragma: no cover - object.__setattr__(self, key, value) - - def __setitem__(self, key: Any, value: Any): - """ - Set attribute `value` identified by `key`. - - Args: - key: Key to set - value: Value to set - - Note: - In the case where value is any list like or array, pandas checks the array length against the number of rows - of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw - a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if - the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use - enlargement filling with the last value in the array. 
- - Returns: - None - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - key = apply_if_callable(key, self) - if isinstance(key, DataFrame) or ( - isinstance(key, np.ndarray) and len(key.shape) == 2 - ): - # This case uses mask's codepath to perform the set, but - # we need to duplicate the code here since we are passing - # an additional kwarg `cond_fillna_with_true` to the QC here. - # We need this additional kwarg, since if df.shape - # and key.shape do not align (i.e. df has more rows), - # mask's codepath would mask the additional rows in df - # while for setitem, we need to keep the original values. - if not isinstance(key, DataFrame): - if key.dtype != bool: - raise TypeError( - "Must pass DataFrame or 2-d ndarray with boolean values only" - ) - key = DataFrame(key) - key._query_compiler._shape_hint = "array" - - if value is not None: - value = apply_if_callable(value, self) - - if isinstance(value, np.ndarray): - value = DataFrame(value) - value._query_compiler._shape_hint = "array" - elif isinstance(value, pd.Series): - # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this - # error instead, since it is more descriptive. - raise ValueError( - "setitem with a 2D key does not support Series values." - ) - - if isinstance(value, BasePandasDataset): - value = value._query_compiler - - query_compiler = self._query_compiler.mask( - cond=key._query_compiler, - other=value, - axis=None, - level=None, - cond_fillna_with_true=True, - ) - - return self._create_or_update_from_compiler(query_compiler, inplace=True) - - # Error Checking: - if (isinstance(key, pd.Series) or is_list_like(key)) and ( - isinstance(value, range) - ): - raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) - elif isinstance(value, slice): - # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. - raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) - - # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column - # key. - index, columns = slice(None), key - index_is_bool_indexer = False - if isinstance(key, slice): - if is_integer(key.start) and is_integer(key.stop): - # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as - # df.iloc[1:2, :] = val - self.iloc[key] = value - return - index, columns = key, slice(None) - elif isinstance(key, pd.Series): - if is_bool_dtype(key.dtype): - index, columns = key, slice(None) - index_is_bool_indexer = True - elif is_bool_indexer(key): - index, columns = pd.Series(key), slice(None) - index_is_bool_indexer = True - - # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case - # we have to explicitly set matching_item_columns_by_label to False for setitem. 
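Putting the dispatch above together, a sketch of how the different key shapes are routed (again assuming a Snowpark session):

```python
import snowflake.snowpark.modin.pandas as pd

df = pd.DataFrame({"a": [1, -2], "b": [-3, 4]})

# 2D boolean key: routed through the mask codepath shown above, with
# cond_fillna_with_true=True so rows absent from the key keep their values.
df[df < 0] = 0

# Integer slice key: treated as a row key, equivalent to df.iloc[0:1, :] = 99.
df[0:1] = 99

# Boolean Series key: also a row key, selecting the rows to assign.
df[pd.Series([True, False])] = -1
```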
- index = index._query_compiler if isinstance(index, BasePandasDataset) else index - columns = ( - columns._query_compiler - if isinstance(columns, BasePandasDataset) - else columns - ) - from .indexing import is_2d_array - - matching_item_rows_by_label = not is_2d_array(value) - if is_2d_array(value): - value = DataFrame(value) - item = value._query_compiler if isinstance(value, BasePandasDataset) else value - new_qc = self._query_compiler.set_2d_labels( - index, - columns, - item, - # setitem always matches item by position - matching_item_columns_by_label=False, - matching_item_rows_by_label=matching_item_rows_by_label, - index_is_bool_indexer=index_is_bool_indexer, - # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling - # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the - # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have - # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns - # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", - # "X", "X". - deduplicate_columns=True, - ) - return self._update_inplace(new_query_compiler=new_qc) - - def abs(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().abs() - - def __and__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__and__", other, axis=1) - - def __rand__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__rand__", other, axis=1) - - def __or__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__or__", other, axis=1) - - def __ror__(self, other): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._binary_op("__ror__", other, axis=1) - - def __neg__(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().__neg__() - - def __iter__(self): - """ - Iterate over info axis. - - Returns - ------- - iterable - Iterator of the columns names. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return iter(self.columns) - - def __contains__(self, key): - """ - Check if `key` in the ``DataFrame.columns``. - - Parameters - ---------- - key : hashable - Key to check the presence in the columns. - - Returns - ------- - bool - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self.columns.__contains__(key) - - def __round__(self, decimals=0): - """ - Round each value in a ``DataFrame`` to the given number of decimals. - - Parameters - ---------- - decimals : int, default: 0 - Number of decimal places to round to. - - Returns - ------- - DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return super().round(decimals) - - @dataframe_not_implemented() - def __delitem__(self, key): - """ - Delete item identified by `key` label. - - Parameters - ---------- - key : hashable - Key to delete. 
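The `deduplicate_columns=True` comment above is easiest to see with a concrete sketch of the Snowpark pandas behavior it describes (hypothetical frames, session required):

```python
import snowflake.snowpark.modin.pandas as pd

df = pd.DataFrame({"A": [1], "B": [2]})
item = pd.DataFrame([[10, 20]])

# __setitem__ deduplicates the key: columns stay ["A", "B"], and "A" takes
# its values from the second "A" column of item (so A becomes 20).
df[["A", "A"]] = item

# loc does not deduplicate: this frame ends up with duplicate "A" columns.
df2 = pd.DataFrame({"A": [1], "B": [2]})
df2.loc[:, ["A", "A"]] = item
```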
- """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if key not in self: - raise KeyError(key) - self._update_inplace(new_query_compiler=self._query_compiler.delitem(key)) - - __add__ = add - __iadd__ = add # pragma: no cover - __radd__ = radd - __mul__ = mul - __imul__ = mul # pragma: no cover - __rmul__ = rmul - __pow__ = pow - __ipow__ = pow # pragma: no cover - __rpow__ = rpow - __sub__ = sub - __isub__ = sub # pragma: no cover - __rsub__ = rsub - __floordiv__ = floordiv - __ifloordiv__ = floordiv # pragma: no cover - __rfloordiv__ = rfloordiv - __truediv__ = truediv - __itruediv__ = truediv # pragma: no cover - __rtruediv__ = rtruediv - __mod__ = mod - __imod__ = mod # pragma: no cover - __rmod__ = rmod - __rdiv__ = rdiv - - def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): - """ - Get a Modin DataFrame that implements the dataframe exchange protocol. - - See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. - - Parameters - ---------- - nan_as_null : bool, default: False - A keyword intended for the consumer to tell the producer - to overwrite null values in the data with ``NaN`` (or ``NaT``). - This currently has no effect; once support for nullable extension - dtypes is added, this value should be propagated to columns. - allow_copy : bool, default: True - A keyword that defines whether or not the library is allowed - to make a copy of the data. For example, copying data would be necessary - if a library supports strided buffers, given that this protocol - specifies contiguous buffers. Currently, if the flag is set to ``False`` - and a copy is needed, a ``RuntimeError`` will be raised. - - Returns - ------- - ProtocolDataframe - A dataframe object following the dataframe protocol specification. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - ErrorMessage.not_implemented( - "Snowpark pandas does not support the DataFrame interchange " - + "protocol method `__dataframe__`. To use Snowpark pandas " - + "DataFrames with third-party libraries that try to call the " - + "`__dataframe__` method, please convert this Snowpark pandas " - + "DataFrame to pandas with `to_pandas()`." - ) - - return self._query_compiler.to_dataframe( - nan_as_null=nan_as_null, allow_copy=allow_copy - ) - - @dataframe_not_implemented() - @property - def attrs(self): # noqa: RT01, D200 - """ - Return dictionary of global attributes of this dataset. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def attrs(df): - return df.attrs - - return self._default_to_pandas(attrs) - - @dataframe_not_implemented() - @property - def style(self): # noqa: RT01, D200 - """ - Return a Styler object. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - def style(df): - """Define __name__ attr because properties do not have it.""" - return df.style - - return self._default_to_pandas(style) - - def isin( - self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike] - ) -> DataFrame: - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(values, dict): - return super().isin(values) - elif isinstance(values, Series): - # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. 
- # if not values.index.is_unique: - # raise ValueError("cannot compute isin with a duplicate axis.") - return self.__constructor__( - query_compiler=self._query_compiler.isin(values._query_compiler) - ) - elif isinstance(values, DataFrame): - # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. - # if not (values.columns.is_unique and values.index.is_unique): - # raise ValueError("cannot compute isin with a duplicate axis.") - return self.__constructor__( - query_compiler=self._query_compiler.isin(values._query_compiler) - ) - else: - if not is_list_like(values): - # throw pandas compatible error - raise TypeError( - "only list-like or dict-like objects are allowed " - f"to be passed to {self.__class__.__name__}.isin(), " - f"you passed a '{type(values).__name__}'" - ) - return super().isin(values) - - def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): - """ - Return or update a ``DataFrame`` with given `new_query_compiler`. - - Parameters - ---------- - new_query_compiler : PandasQueryCompiler - QueryCompiler to use to manage the data. - inplace : bool, default: False - Whether or not to perform update or creation inplace. - - Returns - ------- - DataFrame or None - None if update was done, ``DataFrame`` otherwise. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - assert ( - isinstance(new_query_compiler, type(self._query_compiler)) - or type(new_query_compiler) in self._query_compiler.__class__.__bases__ - ), f"Invalid Query Compiler object: {type(new_query_compiler)}" - if not inplace: - return self.__constructor__(query_compiler=new_query_compiler) - else: - self._update_inplace(new_query_compiler=new_query_compiler) - - def _get_numeric_data(self, axis: int): - """ - Grab only numeric data from ``DataFrame``. - - Parameters - ---------- - axis : {0, 1} - Axis to inspect on having numeric types only. - - Returns - ------- - DataFrame - ``DataFrame`` with numeric data. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # pandas ignores `numeric_only` if `axis` is 1, but we do have to drop - # non-numeric columns if `axis` is 0. - if axis != 0: - return self - return self.drop( - columns=[ - i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) - ] - ) - - def _validate_dtypes(self, numeric_only=False): - """ - Check that all the dtypes are the same. - - Parameters - ---------- - numeric_only : bool, default: False - Whether or not to allow only numeric data. - If True and non-numeric data is found, exception - will be raised. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - dtype = self.dtypes[0] - for t in self.dtypes: - if numeric_only and not is_numeric_dtype(t): - raise TypeError(f"{t} is not a numeric data type") - elif not numeric_only and t != dtype: - raise TypeError(f"Cannot compare type '{t}' with type '{dtype}'") - - def _validate_dtypes_sum_prod_mean(self, axis, numeric_only, ignore_axis=False): - """ - Validate data dtype for `sum`, `prod` and `mean` methods. - - Parameters - ---------- - axis : {0, 1} - Axis to validate over. - numeric_only : bool - Whether or not to allow only numeric data. - If True and non-numeric data is found, exception - will be raised. - ignore_axis : bool, default: False - Whether or not to ignore `axis` parameter. 
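The branches above give `isin` three distinct semantics depending on the argument type; native pandas behaves the same way:

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

# List-like: every cell is tested against the same set of values.
df.isin([1, 4])                # a: [True, False], b: [False, True]

# Dict: values are matched per column label.
df.isin({"a": [1], "b": [3]})  # a: [True, False], b: [True, False]

# Series: aligned on the index, one candidate value per row.
df.isin(pd.Series([1, 4]))     # a: [True, False], b: [False, True]
```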
- - Returns - ------- - DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - # We cannot add datetime types, so if we are summing a column with - # dtype datetime64 and cannot ignore non-numeric types, we must throw a - # TypeError. - if ( - not axis - and numeric_only is False - and any(dtype == np.dtype("datetime64[ns]") for dtype in self.dtypes) - ): - raise TypeError("Cannot add Timestamp Types") - - # If our DataFrame has both numeric and non-numeric dtypes then - # operations between these types do not make sense and we must raise a - # TypeError. The exception to this rule is when there are datetime and - # timedelta objects, in which case we proceed with the comparison - # without ignoring any non-numeric types. We must check explicitly if - # numeric_only is False because if it is None, it will default to True - # if the operation fails with mixed dtypes. - if ( - (axis or ignore_axis) - and numeric_only is False - and np.unique([is_numeric_dtype(dtype) for dtype in self.dtypes]).size == 2 - ): - # check if there are columns with dtypes datetime or timedelta - if all( - dtype != np.dtype("datetime64[ns]") - and dtype != np.dtype("timedelta64[ns]") - for dtype in self.dtypes - ): - raise TypeError("Cannot operate on Numeric and Non-Numeric Types") - - return self._get_numeric_data(axis) if numeric_only else self - - def _to_pandas( - self, - *, - statement_params: dict[str, str] | None = None, - **kwargs: Any, - ) -> pandas.DataFrame: - """ - Convert Snowpark pandas DataFrame to pandas DataFrame - - Args: - statement_params: Dictionary of statement level parameters to be set while executing this action. - - Returns: - pandas DataFrame - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._query_compiler.to_pandas( - statement_params=statement_params, **kwargs - ) - - def _validate_eval_query(self, expr, **kwargs): - """ - Validate the arguments of ``eval`` and ``query`` functions. - - Parameters - ---------- - expr : str - The expression to evaluate. This string cannot contain any - Python statements, only Python expressions. - **kwargs : dict - Optional arguments of ``eval`` and ``query`` functions. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - if isinstance(expr, str) and expr == "": - raise ValueError("expr cannot be an empty string") - - if isinstance(expr, str) and "not" in expr: - if "parser" in kwargs and kwargs["parser"] == "python": - ErrorMessage.not_implemented( # pragma: no cover - "Snowpark pandas does not yet support 'not' in the " - + "expression for the methods `DataFrame.eval` or " - + "`DataFrame.query`" - ) - - def _reduce_dimension(self, query_compiler): - """ - Reduce the dimension of data from the `query_compiler`. - - Parameters - ---------- - query_compiler : BaseQueryCompiler - Query compiler to retrieve the data. - - Returns - ------- - Series - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return Series(query_compiler=query_compiler) - - def _set_axis_name(self, name, axis=0, inplace=False): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - axis = self._get_axis_number(axis) - renamed = self if inplace else self.copy() - if axis == 0: - renamed.index = renamed.index.set_names(name) - else: - renamed.columns = renamed.columns.set_names(name) - if not inplace: - return renamed - - def _to_datetime(self, **kwargs): - """ - Convert `self` to datetime. 
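From the user's perspective, the validation above turns a sum over mixed dtypes into either an error or a narrowed frame, depending on `numeric_only`. A sketch with native pandas, which raises an analogous `TypeError` (the exact message differs from the Snowpark pandas one quoted in the code):

```python
import pandas as pd

df = pd.DataFrame({
    "n": [1, 2],
    "t": pd.to_datetime(["2024-01-01", "2024-01-02"]),
})

try:
    df.sum(axis=0, numeric_only=False)   # datetime64 column cannot be summed
except TypeError as err:
    print(err)

print(df.sum(axis=0, numeric_only=True))  # n    3  -- non-numeric dropped
```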
- - Parameters - ---------- - **kwargs : dict - Optional arguments to use during query compiler's - `to_datetime` invocation. - - Returns - ------- - Series of datetime64 dtype - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return self._reduce_dimension( - query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) - ) - - # Persistance support methods - BEGIN - @classmethod - def _inflate_light(cls, query_compiler): - """ - Re-creates the object from previously-serialized lightweight representation. - - The method is used for faster but not disk-storable persistence. - - Parameters - ---------- - query_compiler : BaseQueryCompiler - Query compiler to use for object re-creation. - - Returns - ------- - DataFrame - New ``DataFrame`` based on the `query_compiler`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return cls(query_compiler=query_compiler) - - @classmethod - def _inflate_full(cls, pandas_df): - """ - Re-creates the object from previously-serialized disk-storable representation. - - Parameters - ---------- - pandas_df : pandas.DataFrame - Data to use for object re-creation. - - Returns - ------- - DataFrame - New ``DataFrame`` based on the `pandas_df`. - """ - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - return cls(data=from_pandas(pandas_df)) - - @dataframe_not_implemented() - def __reduce__(self): - # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions - self._query_compiler.finalize() - # if PersistentPickle.get(): - # return self._inflate_full, (self._to_pandas(),) - return self._inflate_light, (self._query_compiler,) - - # Persistance support methods - END diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py index 2ca9d8e5b83..5024d0618ac 100644 --- a/src/snowflake/snowpark/modin/pandas/general.py +++ b/src/snowflake/snowpark/modin/pandas/general.py @@ -31,7 +31,7 @@ import numpy as np import pandas import pandas.core.common as common -from modin.pandas import Series +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas import IntervalIndex, NaT, Timedelta, Timestamp from pandas._libs import NaTType, lib @@ -65,7 +65,6 @@ # add this line to make doctests runnable from snowflake.snowpark.modin import pandas as pd # noqa: F401 -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import ( is_scalar, raise_if_native_pandas_objects, @@ -92,10 +91,9 @@ # linking to `snowflake.snowpark.DataFrame`, we need to explicitly # qualify return types in this file with `modin.pandas.DataFrame`. 
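`__reduce__` above follows the standard pickle protocol: return a callable plus the arguments needed to re-create the object, which is why `_inflate_light` only round-trips the query compiler (fast, but not disk-durable). A generic, runnable sketch of that protocol with a hypothetical `Demo` class:

```python
import pickle

class Demo:
    def __init__(self, payload):
        self.payload = payload

    def __reduce__(self):
        # pickle.loads will call Demo(payload), mirroring how _inflate_light
        # re-creates a DataFrame from its serialized query compiler.
        return (Demo, (self.payload,))

restored = pickle.loads(pickle.dumps(Demo("state")))
assert restored.payload == "state"
```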
# SNOW-1233342: investigate how to fix these links without using absolute paths + import modin from modin.core.storage_formats import BaseQueryCompiler # pragma: no cover - import snowflake # pragma: no cover - _logger = getLogger(__name__) VALID_DATE_TYPE = Union[ @@ -137,8 +135,8 @@ def notna(obj): # noqa: PR01, RT01, D200 @snowpark_pandas_telemetry_standalone_function_decorator def merge( - left: snowflake.snowpark.modin.pandas.DataFrame | Series, - right: snowflake.snowpark.modin.pandas.DataFrame | Series, + left: modin.pandas.DataFrame | Series, + right: modin.pandas.DataFrame | Series, how: str | None = "inner", on: IndexLabel | None = None, left_on: None @@ -414,7 +412,7 @@ def merge_asof( tolerance: int | Timedelta | None = None, allow_exact_matches: bool = True, direction: str = "backward", -) -> snowflake.snowpark.modin.pandas.DataFrame: +) -> modin.pandas.DataFrame: """ Perform a merge by key distance. @@ -1105,8 +1103,8 @@ def value_counts( @snowpark_pandas_telemetry_standalone_function_decorator def concat( objs: ( - Iterable[snowflake.snowpark.modin.pandas.DataFrame | Series] - | Mapping[Hashable, snowflake.snowpark.modin.pandas.DataFrame | Series] + Iterable[modin.pandas.DataFrame | Series] + | Mapping[Hashable, modin.pandas.DataFrame | Series] ), axis: Axis = 0, join: str = "outer", @@ -1117,7 +1115,7 @@ def concat( verify_integrity: bool = False, sort: bool = False, copy: bool = True, -) -> snowflake.snowpark.modin.pandas.DataFrame | Series: +) -> modin.pandas.DataFrame | Series: """ Concatenate pandas objects along a particular axis. @@ -1490,7 +1488,7 @@ def concat( def to_datetime( arg: DatetimeScalarOrArrayConvertible | DictConvertible - | snowflake.snowpark.modin.pandas.DataFrame + | modin.pandas.DataFrame | Series, errors: DateTimeErrorChoices = "raise", dayfirst: bool = False, diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py index c672f04da63..5da10d9b7a6 100644 --- a/src/snowflake/snowpark/modin/pandas/indexing.py +++ b/src/snowflake/snowpark/modin/pandas/indexing.py @@ -45,6 +45,7 @@ import pandas from modin.pandas import Series from modin.pandas.base import BasePandasDataset +from modin.pandas.dataframe import DataFrame from pandas._libs.tslibs import Resolution, parsing from pandas._typing import AnyArrayLike, Scalar from pandas.api.types import is_bool, is_list_like @@ -61,7 +62,6 @@ import snowflake.snowpark.modin.pandas as pd import snowflake.snowpark.modin.pandas.utils as frontend_utils -from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import is_scalar from snowflake.snowpark.modin.plugin._internal.indexing_utils import ( MULTIPLE_ELLIPSIS_INDEXING_ERROR_MESSAGE, diff --git a/src/snowflake/snowpark/modin/pandas/io.py b/src/snowflake/snowpark/modin/pandas/io.py index 25959212a18..b92e8ee3582 100644 --- a/src/snowflake/snowpark/modin/pandas/io.py +++ b/src/snowflake/snowpark/modin/pandas/io.py @@ -92,7 +92,7 @@ # below logic is to handle circular imports without errors if TYPE_CHECKING: # pragma: no cover - from .dataframe import DataFrame + from modin.pandas.dataframe import DataFrame # TODO: SNOW-1265551: add inherit_docstrings decorators once docstring overrides are available @@ -106,7 +106,7 @@ class ModinObjects: def DataFrame(cls): """Get ``modin.pandas.DataFrame`` class.""" if cls._dataframe is None: - from .dataframe import DataFrame + from modin.pandas.dataframe import DataFrame cls._dataframe = DataFrame return cls._dataframe diff 
--git a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py index 3529355b81b..ee782f3cdf3 100644 --- a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py +++ b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py @@ -5,10 +5,9 @@ from collections.abc import Iterator from typing import Any, Callable +import modin.pandas.dataframe as DataFrame import pandas -import snowflake.snowpark.modin.pandas.dataframe as DataFrame - PARTITION_SIZE = 4096 diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py index 3986e3d52a9..a48f16992d4 100644 --- a/src/snowflake/snowpark/modin/pandas/utils.py +++ b/src/snowflake/snowpark/modin/pandas/utils.py @@ -78,7 +78,7 @@ def from_non_pandas(df, index, columns, dtype): new_qc = FactoryDispatcher.from_non_pandas(df, index, columns, dtype) if new_qc is not None: - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame return DataFrame(query_compiler=new_qc) return new_qc @@ -99,7 +99,7 @@ def from_pandas(df): A new Modin DataFrame object. """ # from modin.core.execution.dispatching.factories.dispatcher import FactoryDispatcher - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_pandas(df)) @@ -118,10 +118,11 @@ def from_arrow(at): DataFrame A new Modin DataFrame object. """ + from modin.pandas import DataFrame + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( FactoryDispatcher, ) - from snowflake.snowpark.modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_arrow(at)) @@ -142,10 +143,11 @@ def from_dataframe(df): DataFrame A new Modin DataFrame object. """ + from modin.pandas import DataFrame + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( FactoryDispatcher, ) - from snowflake.snowpark.modin.pandas import DataFrame return DataFrame(query_compiler=FactoryDispatcher.from_dataframe(df)) @@ -226,7 +228,7 @@ def from_modin_frame_to_mi(df, sortorder=None, names=None): pandas.MultiIndex The pandas.MultiIndex representation of the given DataFrame. """ - from snowflake.snowpark.modin.pandas import DataFrame + from modin.pandas import DataFrame if isinstance(df, DataFrame): df = df._to_pandas() diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py index d3ac525572a..eceb9ca7d7f 100644 --- a/src/snowflake/snowpark/modin/plugin/__init__.py +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -69,6 +69,7 @@ inherit_modules = [ (docstrings.base.BasePandasDataset, modin.pandas.base.BasePandasDataset), + (docstrings.dataframe.DataFrame, modin.pandas.dataframe.DataFrame), (docstrings.series.Series, modin.pandas.series.Series), (docstrings.series_utils.StringMethods, modin.pandas.series_utils.StringMethods), ( @@ -90,17 +91,3 @@ snowflake.snowpark._internal.utils.should_warn_dynamic_pivot_is_in_private_preview = ( False ) - - -# TODO: SNOW-1504302: Modin upgrade - use Snowpark pandas DataFrame for isocalendar -# OSS Modin's DatetimeProperties frontend class wraps the returned query compiler with `modin.pandas.DataFrame`. -# Since we currently replace `pd.DataFrame` with our own Snowpark pandas DataFrame object, this causes errors -# since OSS Modin explicitly imports its own DataFrame class here. 
This override can be removed once the frontend -# DataFrame class is removed from our codebase. -def isocalendar(self): # type: ignore - from snowflake.snowpark.modin.pandas import DataFrame - - return DataFrame(query_compiler=self._query_compiler.dt_isocalendar()) - - -modin.pandas.series_utils.DatetimeProperties.isocalendar = isocalendar diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 01ccad8f430..0005df924db 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -9,7 +9,7 @@ from collections.abc import Hashable, Iterable from functools import partial from inspect import getmembers -from types import BuiltinFunctionType +from types import BuiltinFunctionType, MappingProxyType from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Union import numpy as np @@ -56,6 +56,7 @@ stddev, stddev_pop, sum as sum_, + trunc, var_pop, variance, when, @@ -65,6 +66,9 @@ OrderedDataFrame, OrderingColumn, ) +from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + TimedeltaType, +) from snowflake.snowpark.modin.plugin._internal.utils import ( from_pandas_label, pandas_lit, @@ -85,7 +89,7 @@ } -def array_agg_keepna( +def _array_agg_keepna( column_to_aggregate: ColumnOrName, ordering_columns: Iterable[OrderingColumn] ) -> Column: """ @@ -239,62 +243,63 @@ def _columns_coalescing_idxmax_idxmin_helper( ) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding snowflake builtin aggregation function for axis=0. If any change -# is made to this map, ensure GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE and -# GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES are updated accordingly. -SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": count, - "mean": mean, - "min": min_, - "max": max_, - "idxmax": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmax" - ), - "idxmin": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmin" - ), - "sum": sum_, - "median": median, - "skew": skew, - "std": stddev, - "var": variance, - "all": builtin("booland_agg"), - "any": builtin("boolor_agg"), - np.max: max_, - np.min: min_, - np.sum: sum_, - np.mean: mean, - np.median: median, - np.std: stddev, - np.var: variance, - "array_agg": array_agg, - "quantile": column_quantile, - "nunique": count_distinct, -} -GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = ( - "min", - "max", - "sum", - "mean", - "median", - "std", - np.max, - np.min, - np.sum, - np.mean, - np.median, - np.std, -) -GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = ( - "any", - "all", - "count", - "idxmax", - "idxmin", - "size", - "nunique", -) +class _SnowparkPandasAggregation(NamedTuple): + """ + A representation of a Snowpark pandas aggregation. + + This structure gives us a common representation for an aggregation that may + have multiple aliases, like "sum" and np.sum. + """ + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. 
+ # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool + + # This callable takes a single Snowpark column as input and aggregates the + # column on axis=0. If None, Snowpark pandas does not support this + # aggregation on axis=0. + axis_0_aggregation: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # aggregates the columns on axis=1 with skipna=True, i.e. not including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=True. + axis_1_aggregation_skipna: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # aggregates the columns on axis=1 with skipna=False, i.e. including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=False. + axis_1_aggregation_keepna: Optional[Callable] = None + + +class SnowflakeAggFunc(NamedTuple): + """ + A Snowflake aggregation, including information about how the aggregation acts on SnowparkPandasType. + """ + + # The aggregation function in Snowpark. + # For aggregation on axis=0, this field should take a single Snowpark + # column and return the aggregated column. + # For aggregation on axis=1, this field should take an arbitrary number + # of Snowpark columns and return the aggregated column. + snowpark_aggregation: Callable + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. + # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool class AggFuncWithLabel(NamedTuple): @@ -413,35 +418,143 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable: return sum(builtin("zeroifnull")(col) for col in cols) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding aggregation function for axis=1 when skipna=True. The returned aggregation -# function may either be a builtin aggregation function, or a function taking in *arg columns -# that then calls the appropriate builtin aggregations. -SNOWFLAKE_COLUMNS_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": _columns_count, - "sum": _columns_coalescing_sum, - np.sum: _columns_coalescing_sum, - "min": _columns_coalescing_min, - "max": _columns_coalescing_max, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - np.min: _columns_coalescing_min, - np.max: _columns_coalescing_max, -} +def _create_pandas_to_snowpark_pandas_aggregation_map( + pandas_functions: Iterable[AggFuncTypeBase], + snowpark_pandas_aggregation: _SnowparkPandasAggregation, +) -> MappingProxyType[AggFuncTypeBase, _SnowparkPandasAggregation]: + """ + Create a map from the given pandas functions to the given _SnowparkPandasAggregation.
-# These functions are called instead if skipna=False -SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "min": least, - "max": greatest, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark - # sum_, since Snowpark sum_ gets the sum of all rows within a single column. - "sum": lambda *cols: sum(cols), - np.sum: lambda *cols: sum(cols), - np.min: least, - np.max: greatest, -} + Args: + pandas_functions: The pandas functions that map to the given aggregation. + snowpark_pandas_aggregation: The aggregation to map to. + + Returns: + The map. + """ + return MappingProxyType({k: snowpark_pandas_aggregation for k in pandas_functions}) + + +# Map between the pandas input aggregation function (str or numpy function) and +# _SnowparkPandasAggregation representing information about applying the +# aggregation in Snowpark pandas. +_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[ + AggFuncTypeBase, _SnowparkPandasAggregation +] = MappingProxyType( + { + "count": _SnowparkPandasAggregation( + axis_0_aggregation=count, + axis_1_aggregation_skipna=_columns_count, + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("mean", np.mean), + _SnowparkPandasAggregation( + axis_0_aggregation=mean, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("min", np.min), + _SnowparkPandasAggregation( + axis_0_aggregation=min_, + axis_1_aggregation_keepna=least, + axis_1_aggregation_skipna=_columns_coalescing_min, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("max", np.max), + _SnowparkPandasAggregation( + axis_0_aggregation=max_, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=_columns_coalescing_max, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("sum", np.sum), + _SnowparkPandasAggregation( + axis_0_aggregation=sum_, + # IMPORTANT: sum uses python builtin sum to invoke + # __add__ on each column rather than Snowpark sum_, since + # Snowpark sum_ gets the sum of all rows within a single column.
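To see why the builtin-`sum` trick implements keepna while the `zeroifnull` variant implements skipna, compare the two expressions the map above produces (a sketch; `c1` and `c2` are hypothetical column names, and no session is needed just to build the expressions):

```python
from snowflake.snowpark.functions import builtin, col

cols = (col("c1"), col("c2"))

# skipna=True: NULLs are coalesced to 0 before adding, so a row with
# c1 = 1, c2 = NULL sums to 1.
skipna_sum = sum(builtin("zeroifnull")(c) for c in cols)

# skipna=False: python builtin sum chains Column.__add__, yielding
# 0 + c1 + c2; SQL addition propagates NULL, so the same row sums to NULL.
keepna_sum = sum(cols)
```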
+ axis_1_aggregation_keepna=lambda *cols: sum(cols), + axis_1_aggregation_skipna=_columns_coalescing_sum, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("median", np.median), + _SnowparkPandasAggregation( + axis_0_aggregation=median, + preserves_snowpark_pandas_types=True, + ), + ), + "idxmax": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmax" + ), + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "idxmin": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmin" + ), + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "skew": _SnowparkPandasAggregation( + axis_0_aggregation=skew, + preserves_snowpark_pandas_types=True, + ), + "all": _SnowparkPandasAggregation( + # all() for a column with no non-null values is NULL in Snowflake, but True in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("booland_agg")(col(c)), pandas_lit(True) + ), + preserves_snowpark_pandas_types=False, + ), + "any": _SnowparkPandasAggregation( + # any() for a column with no non-null values is NULL in Snowflake, but False in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("boolor_agg")(col(c)), pandas_lit(False) + ), + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("std", np.std), + _SnowparkPandasAggregation( + axis_0_aggregation=stddev, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("var", np.var), + _SnowparkPandasAggregation( + axis_0_aggregation=variance, + # variance units are the square of the input column units, so + # variance does not preserve types. + preserves_snowpark_pandas_types=False, + ), + ), + "array_agg": _SnowparkPandasAggregation( + axis_0_aggregation=array_agg, + preserves_snowpark_pandas_types=False, + ), + "quantile": _SnowparkPandasAggregation( + axis_0_aggregation=column_quantile, + preserves_snowpark_pandas_types=True, + ), + "nunique": _SnowparkPandasAggregation( + axis_0_aggregation=count_distinct, + preserves_snowpark_pandas_types=False, + ), + } +) class AggregateColumnOpParameters(NamedTuple): @@ -462,7 +575,7 @@ class AggregateColumnOpParameters(NamedTuple): agg_snowflake_quoted_identifier: str # the snowflake aggregation function to apply on the column - snowflake_agg_func: Callable + snowflake_agg_func: SnowflakeAggFunc # the columns specifying the order of rows in the column. This is only # relevant for aggregations that depend on row order, e.g. summing a string @@ -471,88 +584,108 @@ class AggregateColumnOpParameters(NamedTuple): def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool: - return agg_func in SNOWFLAKE_BUILTIN_AGG_FUNC_MAP + return agg_func in _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION def get_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int = 0 -) -> Optional[Callable]: + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] +) -> Optional[SnowflakeAggFunc]: """ Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function. 
If no corresponding snowflake aggregation function can be found, return None. """ - if axis == 0: - snowflake_agg_func = SNOWFLAKE_BUILTIN_AGG_FUNC_MAP.get(agg_func) - if snowflake_agg_func == stddev or snowflake_agg_func == variance: - # for aggregation function std and var, we only support ddof = 0 or ddof = 1. - # when ddof is 1, std is mapped to stddev, var is mapped to variance - # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop - # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 - ddof = agg_kwargs.get("ddof", 1) - if ddof != 1 and ddof != 0: - return None - if ddof == 0: - return stddev_pop if snowflake_agg_func == stddev else var_pop - elif snowflake_agg_func == column_quantile: - interpolation = agg_kwargs.get("interpolation", "linear") - q = agg_kwargs.get("q", 0.5) - if interpolation not in ("linear", "nearest"): - return None - if not is_scalar(q): - # SNOW-1062878 Because list-like q would return multiple rows, calling quantile - # through the aggregate frontend in this manner is unsupported. - return None - return lambda col: column_quantile(col, interpolation, q) - elif agg_func in ("all", "any"): - # If there are no rows in the input frame, the function will also return NULL, which should - # instead by TRUE for "all" and FALSE for "any". - # Need to wrap column name in IDENTIFIER, or else the agg function will treat the name - # as a string literal. - # The generated SQL expression for "all" is - # IFNULL(BOOLAND_AGG(IDENTIFIER("column_name")), TRUE) - # The expression for "any" is - # IFNULL(BOOLOR_AGG(IDENTIFIER("column_name")), FALSE) - default_value = bool(agg_func == "all") - return lambda col: builtin("ifnull")( - # mypy refuses to acknowledge snowflake_agg_func is non-NULL here - snowflake_agg_func(builtin("identifier")(col)), # type: ignore[misc] - pandas_lit(default_value), + if axis == 1: + return _generate_rowwise_aggregation_function(agg_func, agg_kwargs) + + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + + if snowpark_pandas_aggregation is None: + # We don't have any implementation at all for this aggregation. + return None + + snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation + + if snowpark_aggregation is None: + # We don't have an implementation on axis=0 for this aggregation. + return None + + # Rewrite some aggregations according to `agg_kwargs.` + if snowpark_aggregation == stddev or snowpark_aggregation == variance: + # for aggregation function std and var, we only support ddof = 0 or ddof = 1. + # when ddof is 1, std is mapped to stddev, var is mapped to variance + # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop + # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 + ddof = agg_kwargs.get("ddof", 1) + if ddof != 1 and ddof != 0: + return None + if ddof == 0: + snowpark_aggregation = ( + stddev_pop if snowpark_aggregation == stddev else var_pop ) - else: - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) + elif snowpark_aggregation == column_quantile: + interpolation = agg_kwargs.get("interpolation", "linear") + q = agg_kwargs.get("q", 0.5) + if interpolation not in ("linear", "nearest"): + return None + if not is_scalar(q): + # SNOW-1062878 Because list-like q would return multiple rows, calling quantile + # through the aggregate frontend in this manner is unsupported. 
+ return None + + def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: + return column_quantile(col, interpolation, q) - return snowflake_agg_func + assert ( + snowpark_aggregation is not None + ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation." + return SnowflakeAggFunc( + snowpark_aggregation=snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def generate_rowwise_aggregation_function( +def _generate_rowwise_aggregation_function( agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any] -) -> Optional[Callable]: +) -> Optional[SnowflakeAggFunc]: """ Get a callable taking *arg columns to apply for an aggregation. Unlike get_snowflake_agg_func, this function may return a wrapped composition of Snowflake builtin functions depending on the values of the specified kwargs. """ - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) - if not agg_kwargs.get("skipna", True): - snowflake_agg_func = SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP.get( - agg_func, snowflake_agg_func - ) + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + if snowpark_pandas_aggregation is None: + return None + snowpark_aggregation = ( + snowpark_pandas_aggregation.axis_1_aggregation_skipna + if agg_kwargs.get("skipna", True) + else snowpark_pandas_aggregation.axis_1_aggregation_keepna + ) + if snowpark_aggregation is None: + return None min_count = agg_kwargs.get("min_count", 0) if min_count > 0: + original_aggregation = snowpark_aggregation + # Create a case statement to check if the number of non-null values exceeds min_count # when min_count > 0, if the number of not NULL values is < min_count, return NULL. - def agg_func_wrapper(fn: Callable) -> Callable: - return lambda *cols: when( - _columns_count(*cols) < min_count, pandas_lit(None) - ).otherwise(fn(*cols)) + def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn: + return when(_columns_count(*cols) < min_count, pandas_lit(None)).otherwise( + original_aggregation(*cols) + ) - return snowflake_agg_func and agg_func_wrapper(snowflake_agg_func) - return snowflake_agg_func + return SnowflakeAggFunc( + snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def is_supported_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int +def _is_supported_snowflake_agg_func( + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ check if the aggregation function is supported with snowflake. Current supported @@ -566,12 +699,14 @@ def is_supported_snowflake_agg_func( is_valid: bool. Whether it is valid to implement with snowflake or not. """ if isinstance(agg_func, tuple) and len(agg_func) == 2: + # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`, + # take the second part of the named aggregation. agg_func = agg_func[0] return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None -def are_all_agg_funcs_supported_by_snowflake( - agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: int +def _are_all_agg_funcs_supported_by_snowflake( + agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ Check if all aggregation functions in the given list are snowflake supported @@ -582,14 +717,14 @@ def are_all_agg_funcs_supported_by_snowflake( return False. 
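Hypothetical usage of the internal helpers above, showing how a plain function, a list, and a dict of aggregations all reduce to per-function lookups, and how unsupported kwargs such as `ddof=2` make the lookup fail:

```python
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
    check_is_aggregation_supported_in_snowflake,
)

assert check_is_aggregation_supported_in_snowflake("sum", {}, axis=0)
assert check_is_aggregation_supported_in_snowflake(["min", "max"], {}, axis=0)

# var maps to variance/var_pop only for ddof in (0, 1), so this is False.
assert not check_is_aggregation_supported_in_snowflake(
    {"col1": "var"}, {"ddof": 2}, axis=0
)
```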
""" return all( - is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs + _is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs ) def check_is_aggregation_supported_in_snowflake( agg_func: AggFuncType, agg_kwargs: dict[str, Any], - axis: int, + axis: Literal[0, 1], ) -> bool: """ check if distributed implementation with snowflake is available for the aggregation @@ -608,18 +743,18 @@ def check_is_aggregation_supported_in_snowflake( if is_dict_like(agg_func): return all( ( - are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) + _are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) if is_list_like(value) and not is_named_tuple(value) - else is_supported_snowflake_agg_func(value, agg_kwargs, axis) + else _is_supported_snowflake_agg_func(value, agg_kwargs, axis) ) for value in agg_func.values() ) elif is_list_like(agg_func): - return are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) - return is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) + return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) + return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) -def is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: +def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: """ Is the given snowflake aggregation function needs to be applied on the numeric column. """ @@ -697,7 +832,7 @@ def drop_non_numeric_data_columns( ) -def generate_aggregation_column( +def _generate_aggregation_column( agg_column_op_params: AggregateColumnOpParameters, agg_kwargs: dict[str, Any], is_groupby_agg: bool, @@ -721,8 +856,14 @@ def generate_aggregation_column( SnowparkColumn after the aggregation function. The column is also aliased back to the original name """ snowpark_column = agg_column_op_params.snowflake_quoted_identifier - snowflake_agg_func = agg_column_op_params.snowflake_agg_func - if is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( + snowflake_agg_func = agg_column_op_params.snowflake_agg_func.snowpark_aggregation + + if snowflake_agg_func in (variance, var_pop) and isinstance( + agg_column_op_params.data_type, TimedeltaType + ): + raise TypeError("timedelta64 type does not support var operations") + + if _is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( agg_column_op_params.data_type, BooleanType ): # if the column is a boolean column and the aggregation function requires numeric values, @@ -753,7 +894,7 @@ def generate_aggregation_column( # note that we always assume keepna for array_agg. TODO(SNOW-1040398): # make keepna treatment consistent across array_agg and other # aggregation methods. - agg_snowpark_column = array_agg_keepna( + agg_snowpark_column = _array_agg_keepna( snowpark_column, ordering_columns=agg_column_op_params.ordering_columns ) elif ( @@ -825,6 +966,19 @@ def generate_aggregation_column( ), f"No case expression is constructed with skipna({skipna}), min_count({min_count})" agg_snowpark_column = case_expr.otherwise(agg_snowpark_column) + if ( + isinstance(agg_column_op_params.data_type, TimedeltaType) + and agg_column_op_params.snowflake_agg_func.preserves_snowpark_pandas_types + ): + # timedelta aggregations that produce timedelta results might produce + # a decimal type in snowflake, e.g. + # pd.Series([pd.Timestamp(1), pd.Timestamp(2)]).mean() produces 1.5 in + # Snowflake. We truncate the decimal part of the result, as pandas + # does. 
+ agg_snowpark_column = cast( + trunc(agg_snowpark_column), agg_column_op_params.data_type.snowpark_type + ) + # rename the column to agg_column_quoted_identifier agg_snowpark_column = agg_snowpark_column.as_( agg_column_op_params.agg_snowflake_quoted_identifier @@ -857,7 +1011,7 @@ def aggregate_with_ordered_dataframe( is_groupby_agg = groupby_columns is not None agg_list: list[SnowparkColumn] = [ - generate_aggregation_column( + _generate_aggregation_column( agg_column_op_params=agg_col_op, agg_kwargs=agg_kwargs, is_groupby_agg=is_groupby_agg, @@ -973,7 +1127,7 @@ def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str: ) -def generate_pandas_labels_for_agg_result_columns( +def _generate_pandas_labels_for_agg_result_columns( pandas_label: Hashable, num_levels: int, agg_func_list: list[AggFuncInfo], @@ -1102,7 +1256,7 @@ def generate_column_agg_info( ) # generate the pandas label and quoted identifier for the result aggregation columns, one # for each aggregation function to apply. - agg_col_labels = generate_pandas_labels_for_agg_result_columns( + agg_col_labels = _generate_pandas_labels_for_agg_result_columns( pandas_label_to_identifier.pandas_label, num_levels, agg_func_list, # type: ignore[arg-type] diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py index 882dc79d2a8..4eaf98d9b29 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py @@ -189,8 +189,6 @@ def compute_bin_indices( values_frame, cuts_frame, how="asof", - left_on=[], - right_on=[], left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0], right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0], match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO diff --git a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py index c2c224e404c..6207bd2399a 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py @@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame( key, count_frame, "cross", - left_on=[], - right_on=[], inherit_join_index=InheritJoinIndex.FROM_LEFT, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 79f063b9ece..d07211dbcf5 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple): result_column_mapper: JoinOrAlignResultColumnMapper +def assert_snowpark_pandas_types_match( + left: InternalFrame, + right: InternalFrame, + left_join_identifiers: list[str], + right_join_identifiers: list[str], +) -> None: + """ + If Snowpark pandas types do not match for the given identifiers, then a ValueError will be raised. + + Args: + left: An internal frame to use on left side of join. + right: An internal frame to use on right side of join. + left_join_identifiers: List of snowflake identifiers to check types from 'left' frame. + right_join_identifiers: List of snowflake identifiers to check types from 'right' frame. + left_identifiers and right_identifiers must be lists of equal length. 
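+
+    For example, merging a Timedelta column with a plain integer column raises
+    a ValueError, because their Snowpark pandas types differ.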
+ + Returns: None + + Raises: ValueError + """ + left_types = [ + left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in left_join_identifiers + ] + right_types = [ + right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in right_join_identifiers + ] + for i, (lt, rt) in enumerate(zip(left_types, right_types)): + if lt != rt: + left_on_id = left_join_identifiers[i] + idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) + key = left.data_column_pandas_labels[idx] + lt = lt if lt is not None else left.get_snowflake_type(left_on_id) + rt = ( + rt + if rt is not None + else right.get_snowflake_type(right_join_identifiers[i]) + ) + raise ValueError( + f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. " + f"If you wish to proceed you should use pd.concat" + ) + + def join( left: InternalFrame, right: InternalFrame, how: JoinTypeLit, - left_on: list[str], - right_on: list[str], + left_on: Optional[list[str]] = None, + right_on: Optional[list[str]] = None, left_match_col: Optional[str] = None, right_match_col: Optional[str] = None, match_comparator: Optional[MatchComparator] = None, @@ -161,40 +206,48 @@ def join( include mapping for index + data columns, ordering columns and row position column if exists. """ - assert len(left_on) == len( - right_on - ), "left_on and right_on must be of same length or both be None" - if join_key_coalesce_config is not None: - assert len(join_key_coalesce_config) == len( - left_on - ), "join_key_coalesce_config must be of same length as left_on and right_on" assert how in get_args( JoinTypeLit ), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}" - def assert_snowpark_pandas_types_match() -> None: - """If Snowpark pandas types do not match, then a ValueError will be raised.""" - left_types = [ - left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in left_on - ] - right_types = [ - right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in right_on - ] - for i, (lt, rt) in enumerate(zip(left_types, right_types)): - if lt != rt: - left_on_id = left_on[i] - idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) - key = left.data_column_pandas_labels[idx] - lt = lt if lt is not None else left.get_snowflake_type(left_on_id) - rt = rt if rt is not None else right.get_snowflake_type(right_on[i]) - raise ValueError( - f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. 
" - f"If you wish to proceed you should use pd.concat" - ) + left_on = left_on or [] + right_on = right_on or [] + assert len(left_on) == len( + right_on + ), "left_on and right_on must be of same length or both be None" - assert_snowpark_pandas_types_match() + if how == "asof": + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + left_join_key = [left_match_col] + right_join_key = [right_match_col] + left_join_key.extend(left_on) + right_join_key.extend(right_on) + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key" + else: + left_join_key = left_on + right_join_key = right_on + assert ( + left_match_col is None + and right_match_col is None + and match_comparator is None + ), f"match condition should not be provided for {how} join" + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "join_key_coalesce_config must be of same length as left_on and right_on" + + assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key) # Re-project the active columns to make sure all active columns of the internal frame participate # in the join operation, and unnecessary columns are dropped from the projected columns. @@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None: match_comparator=match_comparator, how=how, ) - return _create_internal_frame_with_join_or_align_result( joined_ordered_dataframe, left, right, how, - left_on, - right_on, + left_join_key, + right_join_key, sort, join_key_coalesce_config, inherit_join_index, @@ -1402,6 +1454,9 @@ def _sort_on_join_keys(self) -> None: ) elif self._how == "right": ordering_column_identifiers = mapped_right_on + elif self._how == "asof": + # Order only by the left match_condition column + ordering_column_identifiers = [mapped_left_on[0]] else: # left join, inner join, left align, coalesce align ordering_column_identifiers = mapped_left_on diff --git a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py index f7ae87c2a5d..91537d98e30 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py @@ -1197,22 +1197,29 @@ def join( # get the new mapped right on identifier right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols] - # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' 
- on = None - for left_col, right_col in zip(left_on_cols, right_on_cols): - eq = Column(left_col).equal_null(Column(right_col)) - on = eq if on is None else on & eq - if how == "asof": - assert left_match_col, "left_match_col was not provided to ASOF Join" + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" left_match_col = Column(left_match_col) # Get the new mapped right match condition identifier - assert right_match_col, "right_match_col was not provided to ASOF Join" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" right_match_col = Column(right_identifiers_rename_map[right_match_col]) # ASOF Join requires the use of match_condition - assert match_comparator, "match_comparator was not provided to ASOF Join" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + + on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).__eq__(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right=right_snowpark_dataframe_ref.snowpark_dataframe, + on=on, how=how, match_condition=getattr(left_match_col, match_comparator.value)( right_match_col @@ -1224,6 +1231,12 @@ def join( right_snowpark_dataframe_ref.snowpark_dataframe, how=how ) else: + # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' + on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).equal_null(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right_snowpark_dataframe_ref.snowpark_dataframe, on, how ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py index 3bf1062107e..e7a96b49ef1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py @@ -520,12 +520,15 @@ def single_pivot_helper( data_column_snowflake_quoted_identifiers: new data column snowflake quoted identifiers this pivot result data_column_pandas_labels: new data column pandas labels for this pivot result """ - snowpark_aggr_func = get_snowflake_agg_func(pandas_aggr_func_name, {}) - if not is_supported_snowflake_pivot_agg_func(snowpark_aggr_func): + snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0) + if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func( + snowflake_agg_func.snowpark_aggregation + ): # TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations raise ErrorMessage.not_implemented( f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments." ) + snowpark_aggr_func = snowflake_agg_func.snowpark_aggregation pandas_aggr_label, aggr_snowflake_quoted_identifier = value_label_to_identifier_pair @@ -1231,17 +1234,19 @@ def get_margin_aggregation( Returns: Snowpark column expression for the aggregation function result. """ - resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}) + resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}, axis=0) # This would have been resolved during the original pivot at an early stage. 
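+    # If resolved_aggfunc were None here, an unsupported aggregation would have
+    # slipped past that earlier validation, so fail fast via the assert below.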
assert resolved_aggfunc is not None, "resolved_aggfunc is None" - aggfunc_expr = resolved_aggfunc(snowflake_quoted_identifier) + aggregation_expression = resolved_aggfunc.snowpark_aggregation( + snowflake_quoted_identifier + ) - if resolved_aggfunc == sum_: - aggfunc_expr = coalesce(aggfunc_expr, pandas_lit(0)) + if resolved_aggfunc.snowpark_aggregation == sum_: + aggregation_expression = coalesce(aggregation_expression, pandas_lit(0)) - return aggfunc_expr + return aggregation_expression def expand_pivot_result_with_pivot_table_margins_no_groupby_columns( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py index de83e0429bf..ba8ceedec5e 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py @@ -649,8 +649,6 @@ def perform_asof_join_on_frame( left=preserving_frame, right=referenced_frame, how="asof", - left_on=[], - right_on=[], left_match_col=left_timecol_snowflake_quoted_identifier, right_match_col=right_timecol_snowflake_quoted_identifier, match_comparator=( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py index d38584c14de..e19a6de37ba 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py @@ -567,9 +567,7 @@ def __new__( attrs (Dict[str, Any]): The attributes of the class. Returns: - Union[snowflake.snowpark.modin.pandas.series.Series, - snowflake.snowpark.modin.pandas.dataframe.DataFrame, - snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy, + Union[snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy, snowflake.snowpark.modin.pandas.resample.Resampler, snowflake.snowpark.modin.pandas.window.Window, snowflake.snowpark.modin.pandas.window.Rolling]: diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py index f8629e664f3..3b714087535 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py @@ -525,7 +525,7 @@ def tz_convert_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column: The column after conversion to the specified timezone """ if tz is None: - return convert_timezone(pandas_lit("UTC"), column) + return to_timestamp_ntz(convert_timezone(pandas_lit("UTC"), column)) else: if isinstance(tz, dt.tzinfo): tz_name = tz.tzname(None) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py index 9f01954ab2c..34a3376fcc1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py @@ -1276,7 +1276,7 @@ def check_snowpark_pandas_object_in_arg(arg: Any) -> bool: if check_snowpark_pandas_object_in_arg(v): return True else: - from snowflake.snowpark.modin.pandas import DataFrame, Series + from modin.pandas import DataFrame, Series return isinstance(arg, (DataFrame, Series)) @@ -1519,9 +1519,15 @@ def convert_str_to_timedelta(x: str) -> pd.Timedelta: downcast_pandas_df.columns, cached_snowpark_pandas_types ): if snowpark_pandas_type is not None and snowpark_pandas_type == timedelta_t: - downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply( - convert_str_to_timedelta - ) + # By default, pandas warns, "A value is 
trying to be set on a + # copy of a slice from a DataFrame" here because we are + # assigning a column to downcast_pandas_df, which is a copy of + # a slice of pandas_df. We don't care what happens to pandas_df, + # so the warning isn't useful to us. + with native_pd.option_context("mode.chained_assignment", None): + downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply( + convert_str_to_timedelta + ) # Step 7. postprocessing for return types if index_only: diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index b5022bff46b..a3981379aaf 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -149,8 +149,6 @@ ) from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( AGG_NAME_COL_LABEL, - GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE, - GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES, AggFuncInfo, AggFuncWithLabel, AggregateColumnOpParameters, @@ -161,7 +159,6 @@ convert_agg_func_arg_to_col_agg_func_map, drop_non_numeric_data_columns, generate_column_agg_info, - generate_rowwise_aggregation_function, get_agg_func_to_col_map, get_pandas_aggr_func_name, get_snowflake_agg_func, @@ -2040,8 +2037,8 @@ def binary_op( # Native pandas does not support binary operations between a Series and a list-like object. from modin.pandas import Series + from modin.pandas.dataframe import DataFrame - from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.utils import is_scalar # fail explicitly for unsupported scenarios @@ -3556,42 +3553,22 @@ def convert_func_to_agg_func_info( agg_col_ops, new_data_column_index_names = generate_column_agg_info( internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby ) - # Get the column aggregation functions used to check if the function - # preserves Snowpark pandas types. - agg_col_funcs = [] - for _, func in column_to_agg_func.items(): - if is_list_like(func) and not is_named_tuple(func): - for fn in func: - agg_col_funcs.append(fn.func) - else: - agg_col_funcs.append(func.func) # the pandas label and quoted identifier generated for each result column # after aggregation will be used as new pandas label and quoted identifiers. new_data_column_pandas_labels = [] new_data_column_quoted_identifiers = [] new_data_column_snowpark_pandas_types = [] - for i in range(len(agg_col_ops)): - col_agg_op = agg_col_ops[i] - col_agg_func = agg_col_funcs[i] - new_data_column_pandas_labels.append(col_agg_op.agg_pandas_label) + for agg_col_op in agg_col_ops: + new_data_column_pandas_labels.append(agg_col_op.agg_pandas_label) new_data_column_quoted_identifiers.append( - col_agg_op.agg_snowflake_quoted_identifier + agg_col_op.agg_snowflake_quoted_identifier + ) + new_data_column_snowpark_pandas_types.append( + agg_col_op.data_type + if isinstance(agg_col_op.data_type, SnowparkPandasType) + and agg_col_op.snowflake_agg_func.preserves_snowpark_pandas_types + else None ) - if col_agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE: - new_data_column_snowpark_pandas_types.append( - col_agg_op.data_type - if isinstance(col_agg_op.data_type, SnowparkPandasType) - else None - ) - elif col_agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES: - # In the case where the aggregation overrides the type of the output data column - # (e.g. 
any always returns boolean data columns), set the output Snowpark pandas type - # of the given column to None - new_data_column_snowpark_pandas_types.append(None) # type: ignore - else: - self._raise_not_implemented_error_for_timedelta() - new_data_column_snowpark_pandas_types = None # type: ignore - # The ordering of the named aggregations is changed by us when we process # the agg_kwargs into the func dict (named aggregations on the same # column are moved to be contiguous, see groupby.py::aggregate for an @@ -3644,7 +3621,7 @@ def convert_func_to_agg_func_info( ), agg_pandas_label=None, agg_snowflake_quoted_identifier=row_position_quoted_identifier, - snowflake_agg_func=min_, + snowflake_agg_func=get_snowflake_agg_func("min", agg_kwargs={}, axis=0), ordering_columns=internal_frame.ordering_columns, ) agg_col_ops.append(row_position_agg_column_op) @@ -5657,8 +5634,6 @@ def agg( args: the arguments passed for the aggregation kwargs: keyword arguments passed for the aggregation function. """ - self._raise_not_implemented_error_for_timedelta() - numeric_only = kwargs.get("numeric_only", False) # Call fallback if the aggregation function passed in the arg is currently not supported # by snowflake engine. @@ -5704,6 +5679,11 @@ def agg( not is_list_like(value) for value in func.values() ) if axis == 1: + if any( + isinstance(t, TimedeltaType) + for t in internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values() + ): + ErrorMessage.not_implemented_for_timedelta("agg(axis=1)") if self.is_multiindex(): # TODO SNOW-1010307 fix axis=1 behavior with MultiIndex ErrorMessage.not_implemented( @@ -5761,9 +5741,9 @@ def agg( pandas_column_labels=frame.data_column_pandas_labels, ) if agg_arg in ("idxmin", "idxmax") - else generate_rowwise_aggregation_function(agg_arg, kwargs)( - *(col(c) for c in data_col_identifiers) - ) + else get_snowflake_agg_func( + agg_arg, kwargs, axis=1 + ).snowpark_aggregation(*(col(c) for c in data_col_identifiers)) for agg_arg in agg_args } pandas_labels = list(agg_col_map.keys()) @@ -5883,7 +5863,13 @@ def generate_agg_qc( index_column_snowflake_quoted_identifiers=[ agg_name_col_quoted_identifier ], - data_column_types=None, + data_column_types=[ + col.data_type + if isinstance(col.data_type, SnowparkPandasType) + and col.snowflake_agg_func.preserves_snowpark_pandas_types + else None + for col in col_agg_infos + ], index_column_types=None, ) return SnowflakeQueryCompiler(single_agg_dataframe) @@ -7395,18 +7381,9 @@ def merge_asof( SnowflakeQueryCompiler """ # TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation - if ( - by - or left_by - or right_by - or left_index - or right_index - or tolerance - or suffixes != ("_x", "_y") - ): + if tolerance or suffixes != ("_x", "_y"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method does not currently support parameters " - + "'by', 'left_by', 'right_by', 'left_index', 'right_index', " + "'suffixes', or 'tolerance'" ) if direction not in ("backward", "forward"): @@ -7414,9 +7391,24 @@ def merge_asof( "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'" ) + if direction == "backward": + match_comparator = ( + MatchComparator.GREATER_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.GREATER_THAN + ) + else: + match_comparator = ( + MatchComparator.LESS_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.LESS_THAN + ) + left_frame = self._modin_frame right_frame = right._modin_frame - left_keys, right_keys = 
join_utils.get_join_keys( + # Get the left and right matching key and quoted identifier corresponding to the match_condition + # There will only be matching key/identifier for each table as there is only a single match condition + left_match_keys, right_match_keys = join_utils.get_join_keys( left=left_frame, right=right_frame, on=on, @@ -7425,42 +7417,62 @@ def merge_asof( left_index=left_index, right_index=right_index, ) - left_match_col = ( + left_match_identifier = ( left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - left_keys + left_match_keys )[0][0] ) - right_match_col = ( + right_match_identifier = ( right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - right_keys + right_match_keys )[0][0] ) - - if direction == "backward": - match_comparator = ( - MatchComparator.GREATER_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.GREATER_THAN + coalesce_config = join_utils.get_coalesce_config( + left_keys=left_match_keys, + right_keys=right_match_keys, + external_join_keys=[], + ) + + # Get the left and right matching keys and quoted identifiers corresponding to the 'on' condition + if by or (left_by and right_by): + left_on_keys, right_on_keys = join_utils.get_join_keys( + left=left_frame, + right=right_frame, + on=by, + left_on=left_by, + right_on=right_by, + ) + left_on_identifiers = [ + ids[0] + for ids in left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + left_on_keys + ) + ] + right_on_identifiers = [ + ids[0] + for ids in right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + right_on_keys + ) + ] + coalesce_config.extend( + join_utils.get_coalesce_config( + left_keys=left_on_keys, + right_keys=right_on_keys, + external_join_keys=[], + ) ) else: - match_comparator = ( - MatchComparator.LESS_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.LESS_THAN - ) - - coalesce_config = join_utils.get_coalesce_config( - left_keys=left_keys, right_keys=right_keys, external_join_keys=[] - ) + left_on_identifiers = [] + right_on_identifiers = [] joined_frame, _ = join_utils.join( left=left_frame, right=right_frame, + left_on=left_on_identifiers, + right_on=right_on_identifiers, how="asof", - left_on=[left_match_col], - right_on=[right_match_col], - left_match_col=left_match_col, - right_match_col=right_match_col, + left_match_col=left_match_identifier, + right_match_col=right_match_identifier, match_comparator=match_comparator, join_key_coalesce_config=coalesce_config, sort=True, @@ -9129,7 +9141,9 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler": SnowflakeQueryCompiler Transposed new QueryCompiler object. """ - self._raise_not_implemented_error_for_timedelta() + if len(set(self._modin_frame.cached_data_column_snowpark_pandas_types)) > 1: + # In this case, transpose may lose types. + self._raise_not_implemented_error_for_timedelta() frame = self._modin_frame @@ -12513,8 +12527,6 @@ def _quantiles_single_col( column would allow us to create an accurate row position column, but would require a potentially expensive JOIN operator afterwards to apply the correct index labels. 
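+
+        For example, quantile(q=[0.25, 0.5, 0.75]) on a single column produces
+        a three-row result, with one row per requested quantile.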
""" - self._raise_not_implemented_error_for_timedelta() - assert len(self._modin_frame.data_column_pandas_labels) == 1 if index is not None: @@ -12579,7 +12591,7 @@ def _quantiles_single_col( ], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=[index_identifier], - data_column_types=None, + data_column_types=original_frame.cached_data_column_snowpark_pandas_types, index_column_types=None, ) # We cannot call astype() directly to convert an index column, so we replicate @@ -13613,6 +13625,16 @@ def _window_agg( } ).frame else: + snowflake_agg_func = get_snowflake_agg_func(agg_func, agg_kwargs, axis=0) + if snowflake_agg_func is None: + # We don't have test coverage for this situation because we + # test individual rolling and expanding methods we've implemented, + # like rolling_sum(), but other rolling methods raise + # NotImplementedError immediately. We also don't support rolling + # agg(), which might take us here. + ErrorMessage.not_implemented( # pragma: no cover + f"Window aggregation does not support the aggregation {repr_aggregate_function(agg_func, agg_kwargs)}" + ) new_frame = frame.update_snowflake_quoted_identifiers_with_expressions( { # If aggregation is count use count on row_position_quoted_identifier @@ -13623,7 +13645,7 @@ def _window_agg( if agg_func == "count" else count(col(quoted_identifier)).over(window_expr) >= min_periods, - get_snowflake_agg_func(agg_func, agg_kwargs)( + snowflake_agg_func.snowpark_aggregation( # Expanding is cumulative so replace NULL with 0 for sum aggregation builtin("zeroifnull")(col(quoted_identifier)) if window_func == WindowFunction.EXPANDING @@ -14577,8 +14599,6 @@ def idxmax( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmax", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -14603,8 +14623,6 @@ def idxmin( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmin", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -16668,6 +16686,7 @@ def dt_tz_localize( tz: Union[str, tzinfo], ambiguous: str = "raise", nonexistent: str = "raise", + include_index: bool = False, ) -> "SnowflakeQueryCompiler": """ Localize tz-naive to tz-aware. @@ -16675,39 +16694,50 @@ def dt_tz_localize( tz : str, pytz.timezone, optional ambiguous : {"raise", "inner", "NaT"} or bool mask, default: "raise" nonexistent : {"raise", "shift_forward", "shift_backward, "NaT"} or pandas.timedelta, default: "raise" + include_index: Whether to include the index columns in the operation. Returns: BaseQueryCompiler New QueryCompiler containing values with localized time zone. 
""" + dtype = self.index_dtypes[0] if include_index else self.dtypes[0] + if not include_index: + method_name = "Series.dt.tz_localize" + else: + assert is_datetime64_any_dtype(dtype), "column must be datetime" + method_name = "DatetimeIndex.tz_localize" + if not isinstance(ambiguous, str) or ambiguous != "raise": - ErrorMessage.parameter_not_implemented_error( - "ambiguous", "Series.dt.tz_localize" - ) + ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) if not isinstance(nonexistent, str) or nonexistent != "raise": - ErrorMessage.parameter_not_implemented_error( - "nonexistent", "Series.dt.tz_localize" - ) + ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - lambda column: tz_localize_column(column, tz) + lambda column: tz_localize_column(column, tz), + include_index, ) ) - def dt_tz_convert(self, tz: Union[str, tzinfo]) -> "SnowflakeQueryCompiler": + def dt_tz_convert( + self, + tz: Union[str, tzinfo], + include_index: bool = False, + ) -> "SnowflakeQueryCompiler": """ Convert time-series data to the specified time zone. Args: tz : str, pytz.timezone + include_index: Whether to include the index columns in the operation. Returns: A new QueryCompiler containing values with converted time zone. """ return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - lambda column: tz_convert_column(column, tz) + lambda column: tz_convert_column(column, tz), + include_index, ) ) diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/base.py b/src/snowflake/snowpark/modin/plugin/docstrings/base.py index af50e0379dd..4044f7b675f 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/base.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/base.py @@ -2832,6 +2832,7 @@ def shift(): """ Implement shared functionality between DataFrame and Series for shift. axis argument is only relevant for Dataframe, and should be 0 for Series. + Args: periods : int | Sequence[int] Number of periods to shift. Can be positive or negative. If an iterable of ints, diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py index 6223e9dd273..f7e93e6c2df 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py @@ -1749,7 +1749,7 @@ def info(): ... 'COL2': ['A', 'B', 'C']}) >>> df.info() # doctest: +NORMALIZE_WHITESPACE -")[0] + + f"
{row_count} rows × {col_count} columns
\n" + ) + else: + return result + + +# Upstream modin just uses `to_datetime` rather than `dataframe_to_datetime` on the query compiler. +@register_dataframe_accessor("_to_datetime") +def _to_datetime(self, **kwargs): + """ + Convert `self` to datetime. + + Parameters + ---------- + **kwargs : dict + Optional arguments to use during query compiler's + `to_datetime` invocation. + + Returns + ------- + Series of datetime64 dtype + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._reduce_dimension( + query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) + ) + + +# Snowpark pandas has the extra `statement_params` argument. +@register_dataframe_accessor("_to_pandas") +def _to_pandas( + self, + *, + statement_params: dict[str, str] | None = None, + **kwargs: Any, +) -> native_pd.DataFrame: + """ + Convert Snowpark pandas DataFrame to pandas DataFrame + + Args: + statement_params: Dictionary of statement level parameters to be set while executing this action. + + Returns: + pandas DataFrame + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.to_pandas(statement_params=statement_params, **kwargs) + + +# Snowpark pandas does more validation and error checking than upstream Modin, and uses different +# helper methods for dispatch. +@register_dataframe_accessor("__setitem__") +def __setitem__(self, key: Any, value: Any): + """ + Set attribute `value` identified by `key`. + + Args: + key: Key to set + value: Value to set + + Note: + In the case where value is any list like or array, pandas checks the array length against the number of rows + of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw + a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if + the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use + enlargement filling with the last value in the array. + + Returns: + None + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + key = apply_if_callable(key, self) + if isinstance(key, DataFrame) or ( + isinstance(key, np.ndarray) and len(key.shape) == 2 + ): + # This case uses mask's codepath to perform the set, but + # we need to duplicate the code here since we are passing + # an additional kwarg `cond_fillna_with_true` to the QC here. + # We need this additional kwarg, since if df.shape + # and key.shape do not align (i.e. df has more rows), + # mask's codepath would mask the additional rows in df + # while for setitem, we need to keep the original values. + if not isinstance(key, DataFrame): + if key.dtype != bool: + raise TypeError( + "Must pass DataFrame or 2-d ndarray with boolean values only" + ) + key = DataFrame(key) + key._query_compiler._shape_hint = "array" + + if value is not None: + value = apply_if_callable(value, self) + + if isinstance(value, np.ndarray): + value = DataFrame(value) + value._query_compiler._shape_hint = "array" + elif isinstance(value, pd.Series): + # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this + # error instead, since it is more descriptive. + raise ValueError( + "setitem with a 2D key does not support Series values." 
+ ) + + if isinstance(value, BasePandasDataset): + value = value._query_compiler + + query_compiler = self._query_compiler.mask( + cond=key._query_compiler, + other=value, + axis=None, + level=None, + cond_fillna_with_true=True, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace=True) + + # Error Checking: + if (isinstance(key, pd.Series) or is_list_like(key)) and (isinstance(value, range)): + raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) + elif isinstance(value, slice): + # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. + raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) + + # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column + # key. + index, columns = slice(None), key + index_is_bool_indexer = False + if isinstance(key, slice): + if is_integer(key.start) and is_integer(key.stop): + # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as + # df.iloc[1:2, :] = val + self.iloc[key] = value + return + index, columns = key, slice(None) + elif isinstance(key, pd.Series): + if is_bool_dtype(key.dtype): + index, columns = key, slice(None) + index_is_bool_indexer = True + elif is_bool_indexer(key): + index, columns = pd.Series(key), slice(None) + index_is_bool_indexer = True + + # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case + # we have to explicitly set matching_item_columns_by_label to False for setitem. + index = index._query_compiler if isinstance(index, BasePandasDataset) else index + columns = ( + columns._query_compiler if isinstance(columns, BasePandasDataset) else columns + ) + from snowflake.snowpark.modin.pandas.indexing import is_2d_array + + matching_item_rows_by_label = not is_2d_array(value) + if is_2d_array(value): + value = DataFrame(value) + item = value._query_compiler if isinstance(value, BasePandasDataset) else value + new_qc = self._query_compiler.set_2d_labels( + index, + columns, + item, + # setitem always matches item by position + matching_item_columns_by_label=False, + matching_item_rows_by_label=matching_item_rows_by_label, + index_is_bool_indexer=index_is_bool_indexer, + # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling + # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the + # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have + # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns + # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", + # "X", "X". 
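+            # In short: the column key deduplicates by label, while values from
+            # `item` are always matched to the target columns by position.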
+ deduplicate_columns=True, + ) + return self._update_inplace(new_query_compiler=new_qc) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index df136af1a34..38edb9f7bee 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -960,7 +960,6 @@ def snap(self, freq: Frequency = "S") -> DatetimeIndex: DatetimeIndex(['2023-01-01', '2023-01-01', '2023-02-01', '2023-02-01'], dtype='datetime64[ns]', freq=None) """ - @datetime_index_not_implemented() def tz_convert(self, tz) -> DatetimeIndex: """ Convert tz-aware Datetime Array/Index from one time zone to another. @@ -1025,8 +1024,14 @@ def tz_convert(self, tz) -> DatetimeIndex: '2014-08-01 09:00:00'], dtype='datetime64[ns]', freq='h') """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. + return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_convert( + tz, + include_index=True, + ) + ) - @datetime_index_not_implemented() def tz_localize( self, tz, @@ -1104,21 +1109,29 @@ def tz_localize( Localize DatetimeIndex in US/Eastern time zone: - >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP - >>> tz_aware # doctest: +SKIP - DatetimeIndex(['2018-03-01 09:00:00-05:00', - '2018-03-02 09:00:00-05:00', + >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') + >>> tz_aware + DatetimeIndex(['2018-03-01 09:00:00-05:00', '2018-03-02 09:00:00-05:00', '2018-03-03 09:00:00-05:00'], - dtype='datetime64[ns, US/Eastern]', freq=None) + dtype='datetime64[ns, UTC-05:00]', freq=None) With the ``tz=None``, we can remove the time zone information while keeping the local time (not converted to UTC): - >>> tz_aware.tz_localize(None) # doctest: +SKIP + >>> tz_aware.tz_localize(None) DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', '2018-03-03 09:00:00'], dtype='datetime64[ns]', freq=None) """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. + return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_localize( + tz, + ambiguous, + nonexistent, + include_index=True, + ) + ) def round( self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 12710224de7..b25bb481dc0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -30,6 +30,7 @@ import modin import numpy as np import pandas as native_pd +from modin.pandas import DataFrame, Series from modin.pandas.base import BasePandasDataset from pandas import get_option from pandas._libs import lib @@ -49,7 +50,6 @@ ) from pandas.core.dtypes.inference import is_hashable -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin._internal.timestamp_utils import DateTimeOrigin @@ -71,6 +71,35 @@ } +class IndexParent: + def __init__(self, parent: DataFrame | Series) -> None: + """ + Initialize the IndexParent object. + + IndexParent is used to keep track of the parent object that the Index is a part of. + It tracks the parent object and the parent object's query compiler at the time of creation. 
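+        Storing that query compiler snapshot lets the Index later detect whether
+        the parent has been replaced by an inplace update before propagating
+        index name changes back to the parent.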
+
+        Parameters
+        ----------
+        parent : DataFrame or Series
+            The parent object that the Index is a part of.
+        """
+        assert isinstance(parent, (DataFrame, Series))
+        self._parent = parent
+        self._parent_qc = parent._query_compiler
+
+    def check_and_update_parent_qc_index_names(self, names: list) -> None:
+        """
+        Update the parent's index names only if the query compiler associated with the parent is
+        unchanged from the original query compiler recorded, i.e., no inplace update has been
+        applied to the parent.
+        """
+        if self._parent._query_compiler is self._parent_qc:
+            new_query_compiler = self._parent_qc.set_index_names(names)
+            self._parent._update_inplace(new_query_compiler=new_query_compiler)
+            # Update the query compiler after naming operation.
+            self._parent_qc = new_query_compiler
+
+
 class Index(metaclass=TelemetryMeta):
 
     # Equivalent index type in native pandas
@@ -135,7 +164,7 @@ def __new__(
         index = object.__new__(cls)
         # Initialize the Index
         index._query_compiler = query_compiler
-        # `_parent` keeps track of any Series or DataFrame that this Index is a part of.
+        # `_parent` keeps track of the parent object that this Index is a part of.
         index._parent = None
         return index
 
@@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any:
             ErrorMessage.not_implemented(f"Index.{key} is not yet implemented")
             raise err
 
+    def _set_parent(self, parent: Series | DataFrame) -> None:
+        """
+        Set the parent object and its query compiler.
+
+        Parameters
+        ----------
+        parent : Series or DataFrame
+            The parent object that the Index is a part of.
+        """
+        self._parent = IndexParent(parent)
+
     def _binary_ops(self, method: str, other: Any) -> Index:
         if isinstance(other, Index):
             other = other.to_series().reset_index(drop=True)
@@ -408,12 +448,6 @@ def __constructor__(self):
         """
         return type(self)
 
-    def _set_parent(self, parent: Series | DataFrame):
-        """
-        Set the parent object of the current Index to a given Series or DataFrame.
-        """
-        self._parent = parent
-
     @property
     def values(self) -> ArrayLike:
         """
@@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None:
         if not is_hashable(value):
            raise TypeError(f"{type(self).__name__}.name must be a hashable type")
         self._query_compiler = self._query_compiler.set_index_names([value])
+        # Update the name of the parent's index only if no inplace update has been
+        # performed on the parent object, i.e., the parent's current query compiler
+        # matches the originally recorded query compiler.
         if self._parent is not None:
-            self._parent._update_inplace(
-                new_query_compiler=self._parent._query_compiler.set_index_names([value])
-            )
+            self._parent.check_and_update_parent_qc_index_names([value])
 
     def _get_names(self) -> list[Hashable]:
         """
@@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None:
         if isinstance(values, Index):
             values = values.to_list()
         self._query_compiler = self._query_compiler.set_index_names(values)
+        # Update the name of the parent's index only if the parent's current query compiler
+        # matches the recorded query compiler.
if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names(values) - ) + self._parent.check_and_update_parent_qc_index_names(values) names = property(fset=_set_names, fget=_get_names) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py index c5f9e4f6cee..43f9603cfb4 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py @@ -9,10 +9,10 @@ import inspect from typing import Any, Iterable, Literal, Optional, Union +from modin.pandas import DataFrame, Series from pandas._typing import IndexLabel from snowflake.snowpark import DataFrame as SnowparkDataFrame -from snowflake.snowpark.modin.pandas import DataFrame, Series from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_standalone_function_decorator, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py index dea98bbb0d3..6d6fb4cd0bd 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py @@ -15,6 +15,7 @@ ) import pandas as native_pd +from modin.pandas import DataFrame from pandas._libs.lib import NoDefault, no_default from pandas._typing import ( CSVEngine, @@ -26,7 +27,6 @@ ) import snowflake.snowpark.modin.pandas as pd -from snowflake.snowpark.modin.pandas import DataFrame from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_standalone_function_decorator, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py index f7bba4c743a..5b245bfdab4 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py @@ -181,7 +181,6 @@ def to_pandas( See Also: - :func:`to_pandas