diff --git a/CHANGELOG.md b/CHANGELOG.md
index fea42391259..daedfe34659 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,12 +1,36 @@
# Release History
-## 1.22.0 (TBD)
+## 1.23.0 (TBD)
+
+### Snowpark pandas API Updates
+
+#### Improvements
+
+- Improved `to_pandas` to persist the original timezone offset for TIMESTAMP_TZ type.
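+
+A minimal, hypothetical sketch of this improvement (the timestamp value is illustrative; an active Snowpark session and the Snowpark pandas plugin import are assumed):
+
+```python
+import modin.pandas as pd
+import snowflake.snowpark.modin.plugin  # noqa: F401
+
+s = pd.to_datetime(pd.Series(["2024-01-01 00:00:00 +0900"]))  # stored as TIMESTAMP_TZ
+# With this change, the local pandas result returned by to_pandas keeps the
+# original +09:00 offset instead of losing it.
+local_s = s.to_pandas()
+```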
+
+#### New Features
+
+- Added support for `TimedeltaIndex.mean` method.
+- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`.
+- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`.
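+
+A minimal sketch of the newly supported `by` argument of `pd.merge_asof` (data and column names are illustrative; an active Snowpark session and the Snowpark pandas plugin import are assumed):
+
+```python
+import modin.pandas as pd
+import snowflake.snowpark.modin.plugin  # noqa: F401
+
+quotes = pd.DataFrame(
+    {"time": [1, 2, 3, 4], "ticker": ["A", "A", "B", "B"], "bid": [10.0, 10.5, 20.0, 20.5]}
+)
+trades = pd.DataFrame({"time": [2, 4], "ticker": ["A", "B"], "qty": [100, 200]})
+
+# For each trade, take the latest quote at or before the trade time,
+# matching rows only within the same ticker via the new `by` argument.
+result = pd.merge_asof(trades, quotes, on="time", by="ticker")
+```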
+
+#### Bug Fixes
+
+- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updated the `Series`/`DataFrame`'s index name after an inplace update was applied to the original `Series`/`DataFrame`.
+- Suppressed an unhelpful `SettingWithCopyWarning` that sometimes appeared when printing `Timedelta` columns.
+
+
+## 1.22.1 (2024-09-11)
+This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content.
+
+
+## 1.22.0 (2024-09-10)
### Snowpark Python API Updates
### New Features
-- Added following new functions in `snowflake.snowpark.functions`:
+- Added the following new functions in `snowflake.snowpark.functions`:
- `array_remove`
- `ln`
@@ -46,14 +70,14 @@
- Fixed a bug in `session.read.csv` that caused an error when setting `PARSE_HEADER = True` in an externally defined file format.
- Fixed a bug in query generation from set operations that allowed generation of duplicate queries when children have common subqueries.
- Fixed a bug in `session.get_session_stage` that referenced a non-existing stage after switching database or schema.
-- Fixed a bug where calling `DataFrame.to_snowpark_pandas_dataframe` without explicitly initializing the Snowpark pandas plugin caused an error.
+- Fixed a bug where calling `DataFrame.to_snowpark_pandas` without explicitly initializing the Snowpark pandas plugin caused an error.
- Fixed a bug where using the `explode` function in dynamic table creation caused a SQL compilation error due to improper boolean type casting on the `outer` parameter.
### Snowpark Local Testing Updates
#### New Features
-- Added support for type coercion when passing columns as input to udf calls
+- Added support for type coercion when passing columns as input to UDF calls.
- Added support for `Index.identical`.
#### Bug Fixes
@@ -105,6 +129,9 @@
- Added support for creating a `DatetimeIndex` from an `Index` of numeric or string type.
- Added support for string indexing with `Timedelta` objects.
- Added support for `Series.dt.total_seconds` method.
+- Added support for `DataFrame.apply(axis=0)`.
+- Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`.
+- Added support for `DatetimeIndex.tz_convert` and `DatetimeIndex.tz_localize`.
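+
+A minimal sketch of the newly supported `tz_localize`/`tz_convert`, shown via the `Series.dt` accessor; the `DatetimeIndex` methods are the index-level equivalents (illustrative data; an active Snowpark session and the Snowpark pandas plugin import are assumed):
+
+```python
+import modin.pandas as pd
+import snowflake.snowpark.modin.plugin  # noqa: F401
+
+s = pd.to_datetime(pd.Series(["2024-01-01 00:00:00", "2024-06-01 12:00:00"]))
+# Attach a timezone to the naive timestamps, then convert them to another timezone.
+localized = s.dt.tz_localize("UTC")
+converted = localized.dt.tz_convert("America/Los_Angeles")
+```
+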
#### Improvements
@@ -113,9 +140,11 @@
- Improved `pd.to_datetime` to handle all local input cases.
- Create a lazy index from another lazy index without pulling data to client.
- Raised `NotImplementedError` for Index bitwise operators.
-- Display a clearer error message when `Index.names` is set to a non-like-like object.
+- Display a clearer error message when `Index.names` is set to a non-list-like object.
- Raise a warning whenever MultiIndex values are pulled in locally.
- Improve warning message for `pd.read_snowflake` include the creation reason when temp table creation is triggered.
+- Improve performance for `DataFrame.set_index`, or setting `DataFrame.index` or `Series.index`, by avoiding checks that require eager evaluation. As a consequence, when the new index does not match the length of the current `Series`/`DataFrame` object, a `ValueError` is no longer raised. Instead, when the `Series`/`DataFrame` object is longer than the provided index, the `Series`/`DataFrame`'s new index is filled with `NaN` values for the "extra" elements; otherwise, the extra values in the provided index are ignored. See the example after this list.
+- Properly raise `NotImplementedError` when the `ambiguous`/`nonexistent` arguments of `ceil`/`floor`/`round` are set to non-string values.
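+
+A minimal sketch of the relaxed length check mentioned above (illustrative values; an active Snowpark session and the Snowpark pandas plugin import are assumed):
+
+```python
+import modin.pandas as pd
+import snowflake.snowpark.modin.plugin  # noqa: F401
+
+s = pd.Series([1, 2, 3])
+# The Series is longer than the provided index, so this no longer raises a
+# ValueError; the index becomes [10, 20, NaN] and the data is unchanged.
+s.index = pd.Index([10, 20])
+```
+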
#### Bug Fixes
@@ -126,10 +155,6 @@
- Fixed a bug where `Series.reindex` and `DataFrame.reindex` did not update the result index's name correctly.
- Fixed a bug where `Series.take` did not error when `axis=1` was specified.
-#### Behavior Change
-
-- When calling `DataFrame.set_index`, or setting `DataFrame.index` or `Series.index`, with a new index that does not match the current length of the `Series`/`DataFrame` object, a `ValueError` is no longer raised. When the `Series`/`DataFrame` object is longer than the new index, the `Series`/`DataFrame`'s new index is filled with `NaN` values for the "extra" elements. When the `Series`/`DataFrame` object is shorter than the new index, the extra values in the new index are ignored—`Series` and `DataFrame` stay the same length `n`, and use only the first `n` values of the new index.
-
## 1.21.1 (2024-09-05)
diff --git a/docs/source/modin/series.rst b/docs/source/modin/series.rst
index 188bdab344a..4cb8a238b0f 100644
--- a/docs/source/modin/series.rst
+++ b/docs/source/modin/series.rst
@@ -279,6 +279,8 @@ Series
Series.dt.seconds
Series.dt.microseconds
Series.dt.nanoseconds
+ Series.dt.tz_convert
+ Series.dt.tz_localize
.. rubric:: String accessor methods
diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst
index 6bb214e3bd6..54858063e54 100644
--- a/docs/source/modin/supported/dataframe_supported.rst
+++ b/docs/source/modin/supported/dataframe_supported.rst
@@ -84,7 +84,7 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``any`` | P | | ``N`` for non-integer/boolean types |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``apply`` | P | | ``N`` if ``axis == 0`` or ``func`` is not callable |
+| ``apply`` | P | | ``N`` if ``func`` is not callable |
| | | | or ``result_type`` is given or ``args`` and |
| | | | ``kwargs`` contain DataFrame or Series |
| | | | ``N`` if ``func`` maps to different column labels. |
@@ -471,8 +471,7 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``to_xml`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``transform`` | P | | Only callable and string parameters are supported.|
-| | | | list and dict parameters are not supported. |
+| ``transform`` | P | | ``Y`` if ``func`` is callable. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``transpose`` | P | | See ``T`` |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst
index 68b1935da96..3afe671aee7 100644
--- a/docs/source/modin/supported/datetime_index_supported.rst
+++ b/docs/source/modin/supported/datetime_index_supported.rst
@@ -82,9 +82,9 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``snap`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``tz_convert`` | N | | |
+| ``tz_convert`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``tz_localize`` | N | | |
+| ``tz_localize`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``round`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
diff --git a/docs/source/modin/supported/general_supported.rst b/docs/source/modin/supported/general_supported.rst
index 797ef3bbd59..95d9610202b 100644
--- a/docs/source/modin/supported/general_supported.rst
+++ b/docs/source/modin/supported/general_supported.rst
@@ -38,8 +38,7 @@ Data manipulations
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``merge_asof`` | P | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``. |
-| | | , ``left_index``, ``right_index``| |
+| ``merge_asof`` | P | ``left_index``, ``right_index`` | ``N`` if param ``direction`` is ``nearest``. |
| | | , ``suffixes``, ``tolerance`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge_ordered`` | N | | |
diff --git a/docs/source/modin/supported/series_dt_supported.rst b/docs/source/modin/supported/series_dt_supported.rst
index 3377a3d64e2..68853871ea6 100644
--- a/docs/source/modin/supported/series_dt_supported.rst
+++ b/docs/source/modin/supported/series_dt_supported.rst
@@ -80,9 +80,10 @@ the method in the left column.
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``to_pydatetime`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
-| ``tz_localize`` | N | |
+| ``tz_localize`` | P | ``N`` if ``ambiguous`` or ``nonexistent`` are set to a |
+| | | non-default value. |
+-----------------------------+---------------------------------+----------------------------------------------------+
-| ``tz_convert`` | N | |
+| ``tz_convert`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``normalize`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
diff --git a/docs/source/modin/supported/timedelta_index_supported.rst b/docs/source/modin/supported/timedelta_index_supported.rst
index 49dfcb305e4..f7a34c3552c 100644
--- a/docs/source/modin/supported/timedelta_index_supported.rst
+++ b/docs/source/modin/supported/timedelta_index_supported.rst
@@ -44,7 +44,7 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+
| ``ceil`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+
-| ``mean`` | N | | |
+| ``mean`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+
| ``total_seconds`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+
diff --git a/recipe/meta.yaml b/recipe/meta.yaml
index cf1f2c9ad70..9560f4a4408 100644
--- a/recipe/meta.yaml
+++ b/recipe/meta.yaml
@@ -1,5 +1,5 @@
{% set name = "snowflake-snowpark-python" %}
-{% set version = "1.21.1" %}
+{% set version = "1.22.1" %}
package:
name: {{ name|lower }}
diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
index 76e91b7da92..d8622299ea9 100644
--- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py
+++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -956,10 +956,7 @@ def do_resolve_with_resolved_children(
schema_query = schema_query_for_values_statement(logical_plan.output)
if logical_plan.data:
- if (
- len(logical_plan.output) * len(logical_plan.data)
- < ARRAY_BIND_THRESHOLD
- ):
+ if not logical_plan.is_large_local_data:
return self.plan_builder.query(
values_statement(logical_plan.output, logical_plan.data),
logical_plan,
diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py
index 3ed969caada..22591f55e47 100644
--- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py
@@ -2,11 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
-from typing import AbstractSet, Optional
+from typing import AbstractSet, List, Optional
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
+ derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
@@ -29,6 +30,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.left, self.right)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.left, self.right)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py
index a2d21db4eb2..a7cb5fd97a9 100644
--- a/src/snowflake/snowpark/_internal/analyzer/expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/expression.py
@@ -35,6 +35,13 @@
def derive_dependent_columns(
*expressions: "Optional[Expression]",
) -> Optional[AbstractSet[str]]:
+ """
+    Given a set of expressions, derive the set of columns that the expressions depend on.
+
+    Note that the returned dependent columns form a set without duplicates. For example, given the
+    expression concat(col1, upper(col1), upper(col2)), the result is {col1, col2} even though col1
+    occurs twice in the expression.
+ """
result = set()
for exp in expressions:
if exp is not None:
@@ -48,6 +55,23 @@ def derive_dependent_columns(
return result
+def derive_dependent_columns_with_duplication(
+ *expressions: "Optional[Expression]",
+) -> List[str]:
+ """
+    Given a set of expressions, derive the list of columns that the expressions depend on.
+
+    Note that the returned list contains duplicates when a column occurs more than once in the
+    given expressions. For example, concat(col1, upper(col1), upper(col2)) yields
+    [col1, col1, col2], where col1 appears twice in the result.
+ """
+ result = []
+ for exp in expressions:
+ if exp is not None:
+ result.extend(exp.dependent_column_names_with_duplication())
+ return result
+
+
class Expression:
"""Consider removing attributes, and adding properties and methods.
A subclass of Expression may have no child, one child, or multiple children.
@@ -68,6 +92,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
# TODO: consider adding it to __init__ or use cached_property.
return COLUMN_DEPENDENCY_EMPTY
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return []
+
@property
def pretty_name(self) -> str:
"""Returns a user-facing string representation of this expression's name.
@@ -143,6 +170,9 @@ def __init__(self, plan: "SnowflakePlan") -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return COLUMN_DEPENDENCY_DOLLAR
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return list(COLUMN_DEPENDENCY_DOLLAR)
+
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return self.plan.cumulative_node_complexity
@@ -156,6 +186,9 @@ def __init__(self, expressions: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.expressions)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(*self.expressions)
+
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
@@ -172,6 +205,9 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.columns, *self.values)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.columns, *self.values)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.IN
@@ -212,6 +248,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return {self.name}
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return [self.name]
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN
@@ -235,6 +274,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
else COLUMN_DEPENDENCY_ALL
)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return (
+ derive_dependent_columns_with_duplication(*self.expressions)
+ if self.expressions
+ else [] # we currently do not handle * dependency
+ )
+
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1}
@@ -278,6 +324,14 @@ def __hash__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return self._dependent_column_names
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return (
+ []
+ if (self._dependent_column_names == COLUMN_DEPENDENCY_ALL)
+ or (self._dependent_column_names is None)
+ else list(self._dependent_column_names)
+ )
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN
@@ -371,6 +425,9 @@ def __init__(self, expr: Expression, pattern: Expression) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.pattern)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.expr, self.pattern)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
# expr LIKE pattern
@@ -400,6 +457,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.pattern)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.expr, self.pattern)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
# expr REG_EXP pattern
@@ -423,6 +483,9 @@ def __init__(self, expr: Expression, collation_spec: str) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.expr)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
# expr COLLATE collate_spec
@@ -444,6 +507,9 @@ def __init__(self, expr: Expression, field: str) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.expr)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
# the literal corresponds to the contribution from self.field
@@ -466,6 +532,9 @@ def __init__(self, expr: Expression, field: int) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.expr)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
# the literal corresponds to the contribution from self.field
@@ -510,6 +579,9 @@ def sql(self) -> str:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.children)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(*self.children)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
@@ -525,6 +597,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, *self.order_by_cols)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
# expr WITHIN GROUP (ORDER BY cols)
@@ -549,13 +624,21 @@ def __init__(
self.branches = branches
self.else_value = else_value
- def dependent_column_names(self) -> Optional[AbstractSet[str]]:
+ @property
+ def _child_expressions(self) -> List[Expression]:
exps = []
for exp_tuple in self.branches:
exps.extend(exp_tuple)
if self.else_value is not None:
exps.append(self.else_value)
- return derive_dependent_columns(*exps)
+
+ return exps
+
+ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
+ return derive_dependent_columns(*self._child_expressions)
+
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(*self._child_expressions)
@property
def plan_node_category(self) -> PlanNodeCategory:
@@ -602,6 +685,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.children)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(*self.children)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
@@ -617,6 +703,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.col)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.col)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
@@ -636,6 +725,9 @@ def __init__(self, exprs: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.exprs)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(*self.exprs)
+
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
index 84cd63fd87d..012940471d0 100644
--- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
+++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py
@@ -7,6 +7,7 @@
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
+ derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
@@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.group_by_exprs)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(*self.group_by_exprs)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
@@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
flattened_args = [exp for sublist in self.args for exp in sublist]
return derive_dependent_columns(*flattened_args)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ flattened_args = [exp for sublist in self.args for exp in sublist]
+ return derive_dependent_columns_with_duplication(*flattened_args)
+
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
index e3e032cd94b..aa8730dcf7f 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
@@ -144,10 +144,27 @@ def __init__(
self.data = data
self.schema_query = schema_query
+ @property
+ def is_large_local_data(self) -> bool:
+ from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD
+
+ return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD
+
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
+ if self.is_large_local_data:
+ # When the number of literals exceeds the threshold, we generate 3 queries:
+ # 1. create table query
+ # 2. insert into table query
+ # 3. select * from table query
+ # We only consider the complexity from the final select * query since other queries
+ # are built based on it.
+ return {
+ PlanNodeCategory.COLUMN: 1,
+ }
+
+ # If we stay under the threshold, we generate a single query:
# select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm)
- # TODO: use ARRAY_BIND_THRESHOLD
return {
PlanNodeCategory.COLUMN: len(self.output),
PlanNodeCategory.LITERAL: len(self.data) * len(self.output),
diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
index 1d06f7290a0..82451245e4c 100644
--- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py
@@ -2,11 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
-from typing import AbstractSet, Optional, Type
+from typing import AbstractSet, List, Optional, Type
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
+ derive_dependent_columns_with_duplication,
)
@@ -55,3 +56,6 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.child)
+
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.child)
diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py
index e5886e11069..1ae08e8fde2 100644
--- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py
@@ -2,12 +2,13 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
-from typing import AbstractSet, Dict, Optional
+from typing import AbstractSet, Dict, List, Optional
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
NamedExpression,
derive_dependent_columns,
+ derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
@@ -36,6 +37,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.child)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.child)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py
index 69db3f265ce..4381c4a2e22 100644
--- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py
+++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py
@@ -7,6 +7,7 @@
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
+ derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
@@ -71,6 +72,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.lower, self.upper)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.lower, self.upper)
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
@@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
*self.partition_spec, *self.order_spec, self.frame_spec
)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(
+ *self.partition_spec, *self.order_spec, self.frame_spec
+ )
+
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
# partition_spec order_by_spec frame_spec
@@ -138,6 +147,11 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.window_function, self.window_spec)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(
+ self.window_function, self.window_spec
+ )
+
@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.WINDOW
@@ -171,6 +185,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.default)
+ def dependent_column_names_with_duplication(self) -> List[str]:
+ return derive_dependent_columns_with_duplication(self.expr, self.default)
+
@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
# for func_name
diff --git a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py
index 836628345aa..8d16383a4ce 100644
--- a/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py
+++ b/src/snowflake/snowpark/_internal/compiler/large_query_breakdown.py
@@ -58,11 +58,6 @@
)
from snowflake.snowpark.session import Session
-# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
-# in Snowflake. This is the limit where we start seeing compilation errors.
-COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000
-COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000
-
_logger = logging.getLogger(__name__)
@@ -123,6 +118,12 @@ def __init__(
self._query_generator = query_generator
self.logical_plans = logical_plans
self._parent_map = defaultdict(set)
+ self.complexity_score_lower_bound = (
+ session.large_query_breakdown_complexity_bounds[0]
+ )
+ self.complexity_score_upper_bound = (
+ session.large_query_breakdown_complexity_bounds[1]
+ )
def apply(self) -> List[LogicalPlan]:
if is_active_transaction(self.session):
@@ -183,13 +184,13 @@ def _try_to_breakdown_plan(self, root: TreeNode) -> List[LogicalPlan]:
complexity_score = get_complexity_score(root.cumulative_node_complexity)
_logger.debug(f"Complexity score for root {type(root)} is: {complexity_score}")
- if complexity_score <= COMPLEXITY_SCORE_UPPER_BOUND:
+ if complexity_score <= self.complexity_score_upper_bound:
# Skip optimization if the complexity score is within the upper bound.
return [root]
plans = []
# TODO: SNOW-1617634 Have a one pass algorithm to find the valid node for partitioning
- while complexity_score > COMPLEXITY_SCORE_UPPER_BOUND:
+ while complexity_score > self.complexity_score_upper_bound:
child = self._find_node_to_breakdown(root)
if child is None:
_logger.debug(
@@ -277,7 +278,9 @@ def _is_node_valid_to_breakdown(self, node: LogicalPlan) -> Tuple[bool, int]:
"""
score = get_complexity_score(node.cumulative_node_complexity)
valid_node = (
- COMPLEXITY_SCORE_LOWER_BOUND < score < COMPLEXITY_SCORE_UPPER_BOUND
+ self.complexity_score_lower_bound
+ < score
+ < self.complexity_score_upper_bound
) and self._is_node_pipeline_breaker(node)
if valid_node:
_logger.debug(
diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
index bef53f0f389..3e6dba71be4 100644
--- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
+++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -3,8 +3,12 @@
#
import copy
+import time
from typing import Dict, List
+from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
+ get_complexity_score,
+)
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
PlanQueryType,
Query,
@@ -17,7 +21,11 @@
from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import (
RepeatedSubqueryElimination,
)
+from snowflake.snowpark._internal.compiler.telemetry_constants import (
+ CompilationStageTelemetryField,
+)
from snowflake.snowpark._internal.compiler.utils import create_query_generator
+from snowflake.snowpark._internal.telemetry import TelemetryField
from snowflake.snowpark.mock._connection import MockServerConnection
@@ -68,24 +76,71 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
if self.should_start_query_compilation():
# preparation for compilation
# 1. make a copy of the original plan
+ start_time = time.time()
+ complexity_score_before_compilation = get_complexity_score(
+ self._plan.cumulative_node_complexity
+ )
logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)]
+ deep_copy_end_time = time.time()
+
# 2. create a code generator with the original plan
query_generator = create_query_generator(self._plan)
- # apply each optimizations if needed
+ # 3. apply each optimizations if needed
+ # CTE optimization
+ cte_start_time = time.time()
if self._plan.session.cte_optimization_enabled:
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
logical_plans = repeated_subquery_eliminator.apply()
+
+ cte_end_time = time.time()
+ complexity_scores_after_cte = [
+ get_complexity_score(logical_plan.cumulative_node_complexity)
+ for logical_plan in logical_plans
+ ]
+
+ # Large query breakdown
if self._plan.session.large_query_breakdown_enabled:
large_query_breakdown = LargeQueryBreakdown(
self._plan.session, query_generator, logical_plans
)
logical_plans = large_query_breakdown.apply()
- # do a final pass of code generation
- return query_generator.generate_queries(logical_plans)
+ large_query_breakdown_end_time = time.time()
+ complexity_scores_after_large_query_breakdown = [
+ get_complexity_score(logical_plan.cumulative_node_complexity)
+ for logical_plan in logical_plans
+ ]
+
+ # 4. do a final pass of code generation
+ queries = query_generator.generate_queries(logical_plans)
+
+ # log telemetry data
+ deep_copy_time = deep_copy_end_time - start_time
+ cte_time = cte_end_time - cte_start_time
+ large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time
+ total_time = time.time() - start_time
+ session = self._plan.session
+ summary_value = {
+ TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled,
+ TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled,
+ CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: session.large_query_breakdown_complexity_bounds,
+ CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time,
+ CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time,
+ CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time,
+ CompilationStageTelemetryField.TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN.value: large_query_breakdown_time,
+ CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation,
+ CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte,
+ CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown,
+ }
+ session._conn._telemetry_client.send_query_compilation_summary_telemetry(
+ session_id=session.session_id,
+ plan_uuid=self._plan.uuid,
+ compilation_stage_summary=summary_value,
+ )
+ return queries
else:
final_plan = self._plan
if self._plan.session.cte_optimization_enabled:
diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py
index 3c1f0d4fc5d..be61a1ac924 100644
--- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py
+++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py
@@ -6,10 +6,28 @@
class CompilationStageTelemetryField(Enum):
+ # types
TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED = (
"snowpark_large_query_breakdown_optimization_skipped"
)
+ TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics"
+ TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS = (
+ "snowpark_large_query_breakdown_update_complexity_bounds"
+ )
+
+ # keys
KEY_REASON = "reason"
+ PLAN_UUID = "plan_uuid"
+ TIME_TAKEN_FOR_COMPILATION = "time_taken_for_compilation_sec"
+ TIME_TAKEN_FOR_DEEP_COPY_PLAN = "time_taken_for_deep_copy_plan_sec"
+ TIME_TAKEN_FOR_CTE_OPTIMIZATION = "time_taken_for_cte_optimization_sec"
+ TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN = "time_taken_for_large_query_breakdown_sec"
+ COMPLEXITY_SCORE_BOUNDS = "complexity_score_bounds"
+ COMPLEXITY_SCORE_BEFORE_COMPILATION = "complexity_score_before_compilation"
+ COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization"
+ COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN = (
+ "complexity_score_after_large_query_breakdown"
+ )
class SkipLargeQueryBreakdownCategory(Enum):
diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py
index 05488398d16..025eb57c540 100644
--- a/src/snowflake/snowpark/_internal/telemetry.py
+++ b/src/snowflake/snowpark/_internal/telemetry.py
@@ -79,6 +79,20 @@ class TelemetryField(Enum):
QUERY_PLAN_HEIGHT = "query_plan_height"
QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes"
QUERY_PLAN_COMPLEXITY = "query_plan_complexity"
+ # temp table cleanup
+ TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup"
+ NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned"
+ NUM_TEMP_TABLES_CREATED = "num_temp_tables_created"
+ TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled"
+ TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = (
+ "snowpark_temp_table_cleanup_abnormal_exception"
+ )
+ TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = (
+ "temp_table_cleanup_abnormal_exception_table_name"
+ )
+ TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = (
+ "temp_table_cleanup_abnormal_exception_message"
+ )
# These DataFrame APIs call other DataFrame APIs
@@ -168,6 +182,11 @@ def wrap(*args, **kwargs):
]._session.sql_simplifier_enabled
try:
api_calls[0][TelemetryField.QUERY_PLAN_HEIGHT.value] = plan.plan_height
+            # The uuid of df._select_statement can differ from that of df._plan. Since `plan`
+            # can refer to either one, we cannot rely on plan.uuid; we always use df._plan.uuid
+            # to track the queries.
+ uuid = args[0]._plan.uuid
+ api_calls[0][CompilationStageTelemetryField.PLAN_UUID.value] = uuid
api_calls[0][
TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value
] = plan.num_duplicate_nodes
@@ -369,7 +388,7 @@ def send_sql_simplifier_telemetry(
),
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
- TelemetryField.SQL_SIMPLIFIER_ENABLED.value: True,
+ TelemetryField.SQL_SIMPLIFIER_ENABLED.value: sql_simplifier_enabled,
},
}
self.send(message)
@@ -423,7 +442,25 @@ def send_large_query_breakdown_telemetry(
),
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
- TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: True,
+ TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: value,
+ },
+ }
+ self.send(message)
+
+ def send_query_compilation_summary_telemetry(
+ self,
+ session_id: int,
+ plan_uuid: str,
+ compilation_stage_summary: Dict[str, Any],
+ ) -> None:
+ message = {
+ **self._create_basic_telemetry_data(
+ CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_STATISTICS.value
+ ),
+ TelemetryField.KEY_DATA.value: {
+ TelemetryField.SESSION_ID.value: session_id,
+ CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid,
+ **compilation_stage_summary,
},
}
self.send(message)
@@ -441,3 +478,60 @@ def send_large_query_optimization_skipped_telemetry(
},
}
self.send(message)
+
+ def send_temp_table_cleanup_telemetry(
+ self,
+ session_id: str,
+ temp_table_cleaner_enabled: bool,
+ num_temp_tables_cleaned: int,
+ num_temp_tables_created: int,
+ ) -> None:
+ message = {
+ **self._create_basic_telemetry_data(
+ TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value
+ ),
+ TelemetryField.KEY_DATA.value: {
+ TelemetryField.SESSION_ID.value: session_id,
+ TelemetryField.TEMP_TABLE_CLEANER_ENABLED.value: temp_table_cleaner_enabled,
+ TelemetryField.NUM_TEMP_TABLES_CLEANED.value: num_temp_tables_cleaned,
+ TelemetryField.NUM_TEMP_TABLES_CREATED.value: num_temp_tables_created,
+ },
+ }
+ self.send(message)
+
+ def send_temp_table_cleanup_abnormal_exception_telemetry(
+ self,
+ session_id: str,
+ table_name: str,
+ exception_message: str,
+ ) -> None:
+ message = {
+ **self._create_basic_telemetry_data(
+ TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value
+ ),
+ TelemetryField.KEY_DATA.value: {
+ TelemetryField.SESSION_ID.value: session_id,
+ TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME.value: table_name,
+ TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE.value: exception_message,
+ },
+ }
+ self.send(message)
+
+ def send_large_query_breakdown_update_complexity_bounds(
+ self, session_id: int, lower_bound: int, upper_bound: int
+ ):
+ message = {
+ **self._create_basic_telemetry_data(
+ CompilationStageTelemetryField.TYPE_LARGE_QUERY_BREAKDOWN_UPDATE_COMPLEXITY_BOUNDS.value
+ ),
+ TelemetryField.KEY_DATA.value: {
+ TelemetryField.SESSION_ID.value: session_id,
+ TelemetryField.KEY_DATA.value: {
+ CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: (
+ lower_bound,
+ upper_bound,
+ ),
+ },
+ },
+ }
+ self.send(message)
diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
index b9055c6fc58..4fa17498d34 100644
--- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
+++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
@@ -4,9 +4,7 @@
import logging
import weakref
from collections import defaultdict
-from queue import Empty, Queue
-from threading import Event, Thread
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable
@@ -33,12 +31,6 @@ def __init__(self, session: "Session") -> None:
# to its reference count for later temp table management
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
- # unused temp table will be put into the queue for cleanup
- self.queue: Queue = Queue()
- # thread for removing temp tables (running DROP TABLE sql)
- self.cleanup_thread: Optional[Thread] = None
- # An event managing a flag that indicates whether the cleaner is started
- self.stop_event = Event()
def add(self, table: SnowflakeTable) -> None:
self.ref_count_map[table.name] += 1
@@ -46,61 +38,60 @@ def add(self, table: SnowflakeTable) -> None:
# and this table will be dropped finally
_ = weakref.finalize(table, self._delete_ref_count, table.name)
- def _delete_ref_count(self, name: str) -> None:
+ def _delete_ref_count(self, name: str) -> None: # pragma: no cover
"""
Decrements the reference count of a temporary table,
and if the count reaches zero, puts this table in the queue for cleanup.
"""
self.ref_count_map[name] -= 1
if self.ref_count_map[name] == 0:
- self.ref_count_map.pop(name)
- # clean up
- self.queue.put(name)
+ if self.session.auto_clean_up_temp_table_enabled:
+ self.drop_table(name)
elif self.ref_count_map[name] < 0:
logging.debug(
f"Unexpected reference count {self.ref_count_map[name]} for table {name}"
)
- def process_cleanup(self) -> None:
- while not self.stop_event.is_set():
- try:
- # it's non-blocking after timeout and become interruptable with stop_event
- # it will raise an `Empty` exception if queue is empty after timeout,
- # then we catch this exception and avoid breaking loop
- table_name = self.queue.get(timeout=1)
- self.drop_table(table_name)
- except Empty:
- continue
-
- def drop_table(self, name: str) -> None:
+ def drop_table(self, name: str) -> None: # pragma: no cover
common_log_text = f"temp table {name} in session {self.session.session_id}"
- logging.debug(f"Cleanup Thread: Ready to drop {common_log_text}")
+ logging.debug(f"Ready to drop {common_log_text}")
+ query_id = None
try:
- # TODO SNOW-1556553: Remove this workaround once multi-threading of Snowpark session is supported
- with self.session._conn._conn.cursor() as cursor:
- cursor.execute(
- f"drop table if exists {name} /* internal query to drop unused temp table */",
- _statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name},
+ async_job = self.session.sql(
+ f"drop table if exists {name} /* internal query to drop unused temp table */",
+ )._internal_collect_with_tag_no_telemetry(
+ block=False, statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name}
+ )
+ query_id = async_job.query_id
+ logging.debug(f"Dropping {common_log_text} with query id {query_id}")
+ except Exception as ex: # pragma: no cover
+ warning_message = f"Failed to drop {common_log_text}, exception: {ex}"
+ logging.warning(warning_message)
+ if query_id is None:
+                # If no query_id is available, the query hasn't been accepted by GS and will not
+                # appear in our job_etl_view, so we send a separate telemetry record instead.
+ self.session._conn._telemetry_client.send_temp_table_cleanup_abnormal_exception_telemetry(
+ self.session.session_id,
+ name,
+ str(ex),
)
- logging.debug(f"Cleanup Thread: Successfully dropped {common_log_text}")
- except Exception as ex:
- logging.warning(
- f"Cleanup Thread: Failed to drop {common_log_text}, exception: {ex}"
- ) # pragma: no cover
-
- def is_alive(self) -> bool:
- return self.cleanup_thread is not None and self.cleanup_thread.is_alive()
-
- def start(self) -> None:
- self.stop_event.clear()
- if not self.is_alive():
- self.cleanup_thread = Thread(target=self.process_cleanup)
- self.cleanup_thread.start()
def stop(self) -> None:
"""
- The cleaner will stop immediately and leave unfinished temp tables in the queue.
+ Stops the cleaner (no-op) and sends the telemetry.
"""
- self.stop_event.set()
- if self.is_alive():
- self.cleanup_thread.join()
+ self.session._conn._telemetry_client.send_temp_table_cleanup_telemetry(
+ self.session.session_id,
+ temp_table_cleaner_enabled=self.session.auto_clean_up_temp_table_enabled,
+ num_temp_tables_cleaned=self.num_temp_tables_cleaned,
+ num_temp_tables_created=self.num_temp_tables_created,
+ )
+
+ @property
+ def num_temp_tables_created(self) -> int:
+ return len(self.ref_count_map)
+
+ @property
+ def num_temp_tables_cleaned(self) -> int:
+ # TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled
+ return sum(v == 0 for v in self.ref_count_map.values())
diff --git a/src/snowflake/snowpark/mock/_connection.py b/src/snowflake/snowpark/mock/_connection.py
index 9e8d4d0d721..b384931cb89 100644
--- a/src/snowflake/snowpark/mock/_connection.py
+++ b/src/snowflake/snowpark/mock/_connection.py
@@ -6,6 +6,7 @@
import functools
import json
import logging
+import threading
import uuid
from copy import copy
from decimal import Decimal
@@ -91,35 +92,39 @@ def __init__(self, conn: "MockServerConnection") -> None:
self.table_registry = {}
self.view_registry = {}
self.conn = conn
+ self._lock = self.conn.get_lock()
def is_existing_table(self, name: Union[str, Iterable[str]]) -> bool:
- current_schema = self.conn._get_current_parameter("schema")
- current_database = self.conn._get_current_parameter("database")
- qualified_name = get_fully_qualified_name(
- name, current_schema, current_database
- )
- return qualified_name in self.table_registry
+ with self._lock:
+ current_schema = self.conn._get_current_parameter("schema")
+ current_database = self.conn._get_current_parameter("database")
+ qualified_name = get_fully_qualified_name(
+ name, current_schema, current_database
+ )
+ return qualified_name in self.table_registry
def is_existing_view(self, name: Union[str, Iterable[str]]) -> bool:
- current_schema = self.conn._get_current_parameter("schema")
- current_database = self.conn._get_current_parameter("database")
- qualified_name = get_fully_qualified_name(
- name, current_schema, current_database
- )
- return qualified_name in self.view_registry
+ with self._lock:
+ current_schema = self.conn._get_current_parameter("schema")
+ current_database = self.conn._get_current_parameter("database")
+ qualified_name = get_fully_qualified_name(
+ name, current_schema, current_database
+ )
+ return qualified_name in self.view_registry
def read_table(self, name: Union[str, Iterable[str]]) -> TableEmulator:
- current_schema = self.conn._get_current_parameter("schema")
- current_database = self.conn._get_current_parameter("database")
- qualified_name = get_fully_qualified_name(
- name, current_schema, current_database
- )
- if qualified_name in self.table_registry:
- return copy(self.table_registry[qualified_name])
- else:
- raise SnowparkLocalTestingException(
- f"Object '{name}' does not exist or not authorized."
+ with self._lock:
+ current_schema = self.conn._get_current_parameter("schema")
+ current_database = self.conn._get_current_parameter("database")
+ qualified_name = get_fully_qualified_name(
+ name, current_schema, current_database
)
+ if qualified_name in self.table_registry:
+ return copy(self.table_registry[qualified_name])
+ else:
+ raise SnowparkLocalTestingException(
+ f"Object '{name}' does not exist or not authorized."
+ )
def write_table(
self,
@@ -128,127 +133,155 @@ def write_table(
mode: SaveMode,
column_names: Optional[List[str]] = None,
) -> List[Row]:
- for column in table.columns:
- if not table[column].sf_type.nullable and table[column].isnull().any():
- raise SnowparkLocalTestingException(
- "NULL result in a non-nullable column"
- )
- current_schema = self.conn._get_current_parameter("schema")
- current_database = self.conn._get_current_parameter("database")
- name = get_fully_qualified_name(name, current_schema, current_database)
- table = copy(table)
- if mode == SaveMode.APPEND:
- if name in self.table_registry:
- target_table = self.table_registry[name]
- input_schema = table.columns.to_list()
- existing_schema = target_table.columns.to_list()
-
- if not column_names: # append with column_order being index
- if len(input_schema) != len(existing_schema):
- raise SnowparkLocalTestingException(
- f"Cannot append because incoming data has different schema {input_schema} than existing table {existing_schema}"
- )
- # temporarily align the column names of both dataframe to be col indexes 0, 1, ... N - 1
- table.columns = range(table.shape[1])
- target_table.columns = range(target_table.shape[1])
- else: # append with column_order being name
- if invalid_cols := set(input_schema) - set(existing_schema):
- identifiers = "', '".join(
- unquote_if_quoted(id) for id in invalid_cols
- )
- raise SnowparkLocalTestingException(
- f"table contains invalid identifier '{identifiers}'"
- )
- invalid_non_nullable_cols = []
- for missing_col in set(existing_schema) - set(input_schema):
- if target_table[missing_col].sf_type.nullable:
- table[missing_col] = None
- table.sf_types[missing_col] = target_table[
- missing_col
- ].sf_type
- else:
- invalid_non_nullable_cols.append(missing_col)
- if invalid_non_nullable_cols:
- identifiers = "', '".join(
- unquote_if_quoted(id)
- for id in invalid_non_nullable_cols
- )
- raise SnowparkLocalTestingException(
- f"NULL result in a non-nullable column '{identifiers}'"
- )
-
- self.table_registry[name] = pandas.concat(
- [target_table, table], ignore_index=True
- )
- self.table_registry[name].columns = existing_schema
- self.table_registry[name].sf_types = target_table.sf_types
- else:
- self.table_registry[name] = table
- elif mode == SaveMode.IGNORE:
- if name not in self.table_registry:
- self.table_registry[name] = table
- elif mode == SaveMode.OVERWRITE:
- self.table_registry[name] = table
- elif mode == SaveMode.ERROR_IF_EXISTS:
- if name in self.table_registry:
- raise SnowparkLocalTestingException(f"Table {name} already exists")
- else:
- self.table_registry[name] = table
- elif mode == SaveMode.TRUNCATE:
- if name in self.table_registry:
- target_table = self.table_registry[name]
- input_schema = set(table.columns.to_list())
- existing_schema = set(target_table.columns.to_list())
- # input is a subset of existing schema and all missing columns are nullable
- if input_schema.issubset(existing_schema) and all(
- target_table[col].sf_type.nullable
- for col in set(existing_schema - input_schema)
+ with self._lock:
+ for column in table.columns:
+ if (
+ not table[column].sf_type.nullable
+ and table[column].isnull().any()
):
- for col in set(existing_schema - input_schema):
- table[col] = ColumnEmulator(
- data=[None] * table.shape[0],
- sf_type=target_table[col].sf_type,
- dtype=object,
- )
+ raise SnowparkLocalTestingException(
+ "NULL result in a non-nullable column"
+ )
+ current_schema = self.conn._get_current_parameter("schema")
+ current_database = self.conn._get_current_parameter("database")
+ name = get_fully_qualified_name(name, current_schema, current_database)
+ table = copy(table)
+ if mode == SaveMode.APPEND:
+ if name in self.table_registry:
+ target_table = self.table_registry[name]
+ input_schema = table.columns.to_list()
+ existing_schema = target_table.columns.to_list()
+
+ if not column_names: # append with column_order being index
+ if len(input_schema) != len(existing_schema):
+ raise SnowparkLocalTestingException(
+ f"Cannot append because incoming data has different schema {input_schema} than existing table {existing_schema}"
+ )
+ # temporarily align the column names of both dataframe to be col indexes 0, 1, ... N - 1
+ table.columns = range(table.shape[1])
+ target_table.columns = range(target_table.shape[1])
+ else: # append with column_order being name
+ if invalid_cols := set(input_schema) - set(existing_schema):
+ identifiers = "', '".join(
+ unquote_if_quoted(id) for id in invalid_cols
+ )
+ raise SnowparkLocalTestingException(
+ f"table contains invalid identifier '{identifiers}'"
+ )
+ invalid_non_nullable_cols = []
+ for missing_col in set(existing_schema) - set(input_schema):
+ if target_table[missing_col].sf_type.nullable:
+ table[missing_col] = None
+ table.sf_types[missing_col] = target_table[
+ missing_col
+ ].sf_type
+ else:
+ invalid_non_nullable_cols.append(missing_col)
+ if invalid_non_nullable_cols:
+ identifiers = "', '".join(
+ unquote_if_quoted(id)
+ for id in invalid_non_nullable_cols
+ )
+ raise SnowparkLocalTestingException(
+ f"NULL result in a non-nullable column '{identifiers}'"
+ )
+
+ self.table_registry[name] = pandas.concat(
+ [target_table, table], ignore_index=True
+ )
+ self.table_registry[name].columns = existing_schema
+ self.table_registry[name].sf_types = target_table.sf_types
else:
+ self.table_registry[name] = table
+ elif mode == SaveMode.IGNORE:
+ if name not in self.table_registry:
+ self.table_registry[name] = table
+ elif mode == SaveMode.OVERWRITE:
+ self.table_registry[name] = table
+ elif mode == SaveMode.ERROR_IF_EXISTS:
+ if name in self.table_registry:
raise SnowparkLocalTestingException(
- f"Cannot truncate because incoming data has different schema {table.columns.to_list()} than existing table { target_table.columns.to_list()}"
+ f"Table {name} already exists"
)
- table.sf_types_by_col_index = target_table.sf_types_by_col_index
- table = table.reindex(columns=target_table.columns)
- self.table_registry[name] = table
- else:
- raise SnowparkLocalTestingException(f"Unrecognized mode: {mode}")
- return [
- Row(status=f"Table {name} successfully created.")
- ] # TODO: match message
+ else:
+ self.table_registry[name] = table
+ elif mode == SaveMode.TRUNCATE:
+ if name in self.table_registry:
+ target_table = self.table_registry[name]
+ input_schema = set(table.columns.to_list())
+ existing_schema = set(target_table.columns.to_list())
+ # input is a subset of existing schema and all missing columns are nullable
+ if input_schema.issubset(existing_schema) and all(
+ target_table[col].sf_type.nullable
+ for col in set(existing_schema - input_schema)
+ ):
+ for col in set(existing_schema - input_schema):
+ table[col] = ColumnEmulator(
+ data=[None] * table.shape[0],
+ sf_type=target_table[col].sf_type,
+ dtype=object,
+ )
+ else:
+ raise SnowparkLocalTestingException(
+                        f"Cannot truncate because incoming data has different schema {table.columns.to_list()} than existing table {target_table.columns.to_list()}"
+ )
+ table.sf_types_by_col_index = target_table.sf_types_by_col_index
+ table = table.reindex(columns=target_table.columns)
+ self.table_registry[name] = table
+ else:
+ raise SnowparkLocalTestingException(f"Unrecognized mode: {mode}")
+ return [
+ Row(status=f"Table {name} successfully created.")
+ ] # TODO: match message
def drop_table(self, name: Union[str, Iterable[str]]) -> None:
- current_schema = self.conn._get_current_parameter("schema")
- current_database = self.conn._get_current_parameter("database")
- name = get_fully_qualified_name(name, current_schema, current_database)
- if name in self.table_registry:
- self.table_registry.pop(name)
+ with self._lock:
+ current_schema = self.conn._get_current_parameter("schema")
+ current_database = self.conn._get_current_parameter("database")
+ name = get_fully_qualified_name(name, current_schema, current_database)
+ if name in self.table_registry:
+ self.table_registry.pop(name)
def create_or_replace_view(
self, execution_plan: MockExecutionPlan, name: Union[str, Iterable[str]]
):
- current_schema = self.conn._get_current_parameter("schema")
- current_database = self.conn._get_current_parameter("database")
- name = get_fully_qualified_name(name, current_schema, current_database)
- self.view_registry[name] = execution_plan
+ with self._lock:
+ current_schema = self.conn._get_current_parameter("schema")
+ current_database = self.conn._get_current_parameter("database")
+ name = get_fully_qualified_name(name, current_schema, current_database)
+ self.view_registry[name] = execution_plan
def get_review(self, name: Union[str, Iterable[str]]) -> MockExecutionPlan:
- current_schema = self.conn._get_current_parameter("schema")
- current_database = self.conn._get_current_parameter("database")
- name = get_fully_qualified_name(name, current_schema, current_database)
- if name in self.view_registry:
- return self.view_registry[name]
- raise SnowparkLocalTestingException(f"View {name} does not exist")
+ with self._lock:
+ current_schema = self.conn._get_current_parameter("schema")
+ current_database = self.conn._get_current_parameter("database")
+ name = get_fully_qualified_name(name, current_schema, current_database)
+ if name in self.view_registry:
+ return self.view_registry[name]
+ raise SnowparkLocalTestingException(f"View {name} does not exist")
+
+ def read_view_if_exists(
+ self, name: Union[str, Iterable[str]]
+ ) -> Optional[MockExecutionPlan]:
+ """Method to atomically read a view if it exists. Returns None if the view does not exist."""
+ with self._lock:
+ if self.is_existing_view(name):
+ return self.get_review(name)
+ return None
+
+ def read_table_if_exists(
+ self, name: Union[str, Iterable[str]]
+ ) -> Optional[TableEmulator]:
+ """Method to atomically read a table if it exists. Returns None if the table does not exist."""
+ with self._lock:
+ if self.is_existing_table(name):
+ return self.read_table(name)
+ return None
def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
self._conn = MockedSnowflakeConnection()
self._cursor = Mock()
+ self._lock = threading.RLock()
self._lower_case_parameters = {}
self.remove_query_listener = Mock()
self.add_query_listener = Mock()
@@ -301,7 +334,7 @@ def log_not_supported_error(
warning_logger: Optional[logging.Logger] = None,
):
"""
- send telemetry to oob servie, can raise error or logging a warning based upon the input
+    send telemetry to the OOB service; can raise an error or log a warning based on the input
Args:
external_feature_name: customer facing feature name, this information is used to raise error
@@ -323,25 +356,31 @@ def log_not_supported_error(
def _get_client_side_session_parameter(self, name: str, default_value: Any) -> Any:
# mock implementation
- return (
- self._conn._session_parameters.get(name, default_value)
- if self._conn._session_parameters
- else default_value
- )
+ with self._lock:
+ return (
+ self._conn._session_parameters.get(name, default_value)
+ if self._conn._session_parameters
+ else default_value
+ )
def get_session_id(self) -> int:
return 1
+ def get_lock(self):
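+        """Return the reentrant lock used to synchronize access to this mock connection's state."""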
+ return self._lock
+
def close(self) -> None:
- if self._conn:
- self._conn.close()
+ with self._lock:
+ if self._conn:
+ self._conn.close()
def is_closed(self) -> bool:
return self._conn.is_closed()
def _get_current_parameter(self, param: str, quoted: bool = True) -> Optional[str]:
try:
- name = getattr(self, f"_active_{param}", None)
+ with self._lock:
+ name = getattr(self, f"_active_{param}", None)
if name and len(name) >= 2 and name[0] == name[-1] == '"':
# it is a quoted identifier, return the original value
return name
diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py
index edf9ffc68b3..3842f6fda34 100644
--- a/src/snowflake/snowpark/mock/_functions.py
+++ b/src/snowflake/snowpark/mock/_functions.py
@@ -10,6 +10,7 @@
import operator
import re
import string
+import threading
from decimal import Decimal
from functools import partial, reduce
from numbers import Real
@@ -130,14 +131,17 @@ def __call__(self, *args, input_data=None, row_number=None, **kwargs):
class MockedFunctionRegistry:
_instance = None
+ _lock_init = threading.Lock()
def __init__(self) -> None:
self._registry = dict()
+ self._lock = threading.RLock()
@classmethod
def get_or_create(cls) -> "MockedFunctionRegistry":
- if cls._instance is None:
- cls._instance = MockedFunctionRegistry()
+ with cls._lock_init:
+ if cls._instance is None:
+ cls._instance = MockedFunctionRegistry()
return cls._instance
def get_function(
@@ -151,10 +155,11 @@ def get_function(
distinct = func.is_distinct
func_name = func_name.lower()
- if func_name not in self._registry:
- return None
+ with self._lock:
+ if func_name not in self._registry:
+ return None
- function = self._registry[func_name]
+ function = self._registry[func_name]
return function.distinct if distinct else function
@@ -169,7 +174,8 @@ def register(
snowpark_func if isinstance(snowpark_func, str) else snowpark_func.__name__
)
mocked_function = MockedFunction(name, func_implementation, *args, **kwargs)
- self._registry[name] = mocked_function
+ with self._lock:
+ self._registry[name] = mocked_function
return mocked_function
def unregister(
@@ -180,8 +186,9 @@ def unregister(
snowpark_func if isinstance(snowpark_func, str) else snowpark_func.__name__
)
- if name in self._registry:
- del self._registry[name]
+ with self._lock:
+ if name in self._registry:
+ del self._registry[name]
class LocalTimezone:
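The `MockedFunctionRegistry` change above guards lazy construction of the singleton with a class-level lock so two threads racing on a cold start cannot each create their own instance. A sketch of the same idea with an illustrative class name:

```
import threading


class Registry:
    _instance = None
    _init_lock = threading.Lock()

    @classmethod
    def get_or_create(cls) -> "Registry":
        # Serialize first-time construction; later calls still pay only the
        # cost of an uncontended lock acquisition.
        with cls._init_lock:
            if cls._instance is None:
                cls._instance = cls()
        return cls._instance
```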
diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py
index 11e54802eea..aa86b2598d6 100644
--- a/src/snowflake/snowpark/mock/_plan.py
+++ b/src/snowflake/snowpark/mock/_plan.py
@@ -357,18 +357,21 @@ def handle_function_expression(
current_row=None,
):
func = MockedFunctionRegistry.get_or_create().get_function(exp)
+ connection_lock = analyzer.session._conn.get_lock()
if func is None:
- current_schema = analyzer.session.get_current_schema()
- current_database = analyzer.session.get_current_database()
+ with connection_lock:
+ current_schema = analyzer.session.get_current_schema()
+ current_database = analyzer.session.get_current_database()
udf_name = get_fully_qualified_name(exp.name, current_schema, current_database)
# If udf name in the registry then this is a udf, not an actual function
- if udf_name in analyzer.session.udf._registry:
- exp.udf_name = udf_name
- return handle_udf_expression(
- exp, input_data, analyzer, expr_to_alias, current_row
- )
+ with connection_lock:
+ if udf_name in analyzer.session.udf._registry:
+ exp.udf_name = udf_name
+ return handle_udf_expression(
+ exp, input_data, analyzer, expr_to_alias, current_row
+ )
if exp.api_call_source == "functions.call_udf":
raise SnowparkLocalTestingException(
@@ -463,9 +466,12 @@ def handle_udf_expression(
):
udf_registry = analyzer.session.udf
udf_name = exp.udf_name
- udf = udf_registry.get_udf(udf_name)
+ connection_lock = analyzer.session._conn.get_lock()
+ with connection_lock:
+ udf = udf_registry.get_udf(udf_name)
+ udf_imports = udf_registry.get_udf_imports(udf_name)
- with ImportContext(udf_registry.get_udf_imports(udf_name)):
+ with ImportContext(udf_imports):
# Resolve handler callable
if type(udf.func) is tuple:
module_name, handler_name = udf.func
@@ -556,6 +562,7 @@ def execute_mock_plan(
analyzer = plan.analyzer
entity_registry = analyzer.session._conn.entity_registry
+ connection_lock = analyzer.session._conn.get_lock()
if isinstance(source_plan, SnowflakeValues):
table = TableEmulator(
@@ -728,18 +735,20 @@ def execute_mock_plan(
return res_df
if isinstance(source_plan, MockSelectableEntity):
entity_name = source_plan.entity.name
- if entity_registry.is_existing_table(entity_name):
- return entity_registry.read_table(entity_name)
- elif entity_registry.is_existing_view(entity_name):
- execution_plan = entity_registry.get_review(entity_name)
+ table = entity_registry.read_table_if_exists(entity_name)
+ if table is not None:
+ return table
+
+ execution_plan = entity_registry.read_view_if_exists(entity_name)
+ if execution_plan is not None:
res_df = execute_mock_plan(execution_plan, expr_to_alias)
return res_df
- else:
- db_schme_table = parse_table_name(entity_name)
- table = ".".join([part.strip("\"'") for part in db_schme_table[:3]])
- raise SnowparkLocalTestingException(
- f"Object '{table}' does not exist or not authorized."
- )
+
+ db_schema_table = parse_table_name(entity_name)
+ table = ".".join([part.strip("\"'") for part in db_schema_table[:3]])
+ raise SnowparkLocalTestingException(
+ f"Object '{table}' does not exist or not authorized."
+ )
if isinstance(source_plan, Aggregate):
child_rf = execute_mock_plan(source_plan.child, expr_to_alias)
if (
@@ -1111,28 +1120,30 @@ def outer_join(base_df):
)
if isinstance(source_plan, SnowflakeTable):
entity_name = source_plan.name
- if entity_registry.is_existing_table(entity_name):
- return entity_registry.read_table(entity_name)
- elif entity_registry.is_existing_view(entity_name):
- execution_plan = entity_registry.get_review(entity_name)
+ table = entity_registry.read_table_if_exists(entity_name)
+ if table is not None:
+ return table
+
+ execution_plan = entity_registry.read_view_if_exists(entity_name)
+ if execution_plan is not None:
res_df = execute_mock_plan(execution_plan, expr_to_alias)
return res_df
- else:
- obj_name_tuple = parse_table_name(entity_name)
- obj_name = obj_name_tuple[-1]
- obj_schema = (
- obj_name_tuple[-2]
- if len(obj_name_tuple) > 1
- else analyzer.session.get_current_schema()
- )
- obj_database = (
- obj_name_tuple[-3]
- if len(obj_name_tuple) > 2
- else analyzer.session.get_current_database()
- )
- raise SnowparkLocalTestingException(
- f"Object '{obj_database[1:-1]}.{obj_schema[1:-1]}.{obj_name[1:-1]}' does not exist or not authorized."
- )
+
+ obj_name_tuple = parse_table_name(entity_name)
+ obj_name = obj_name_tuple[-1]
+ obj_schema = (
+ obj_name_tuple[-2]
+ if len(obj_name_tuple) > 1
+ else analyzer.session.get_current_schema()
+ )
+ obj_database = (
+ obj_name_tuple[-3]
+ if len(obj_name_tuple) > 2
+ else analyzer.session.get_current_database()
+ )
+ raise SnowparkLocalTestingException(
+ f"Object '{obj_database[1:-1]}.{obj_schema[1:-1]}.{obj_name[1:-1]}' does not exist or not authorized."
+ )
if isinstance(source_plan, Sample):
res_df = execute_mock_plan(source_plan.child, expr_to_alias)
@@ -1159,272 +1170,283 @@ def outer_join(base_df):
return from_df
if isinstance(source_plan, TableUpdate):
- target = entity_registry.read_table(source_plan.table_name)
- ROW_ID = "row_id_" + generate_random_alphanumeric()
- target.insert(0, ROW_ID, range(len(target)))
+ # since we are modifying the table, we need to ensure that no other thread
+ # reads the table until it is updated
+ with connection_lock:
+ target = entity_registry.read_table(source_plan.table_name)
+ ROW_ID = "row_id_" + generate_random_alphanumeric()
+ target.insert(0, ROW_ID, range(len(target)))
+
+ if source_plan.source_data:
+ # Calculate cartesian product
+ source = execute_mock_plan(source_plan.source_data, expr_to_alias)
+ cartesian_product = target.merge(source, on=None, how="cross")
+ cartesian_product.sf_types.update(target.sf_types)
+ cartesian_product.sf_types.update(source.sf_types)
+ intermediate = cartesian_product
+ else:
+ intermediate = target
- if source_plan.source_data:
- # Calculate cartesian product
- source = execute_mock_plan(source_plan.source_data, expr_to_alias)
- cartesian_product = target.merge(source, on=None, how="cross")
- cartesian_product.sf_types.update(target.sf_types)
- cartesian_product.sf_types.update(source.sf_types)
- intermediate = cartesian_product
- else:
- intermediate = target
+ if source_plan.condition:
+ # Select rows to be updated based on condition
+ condition = calculate_expression(
+ source_plan.condition, intermediate, analyzer, expr_to_alias
+ ).fillna(value=False)
- if source_plan.condition:
- # Select rows to be updated based on condition
- condition = calculate_expression(
- source_plan.condition, intermediate, analyzer, expr_to_alias
- ).fillna(value=False)
-
- matched = target.apply(tuple, 1).isin(
- intermediate[condition][target.columns].apply(tuple, 1)
+ matched = target.apply(tuple, 1).isin(
+ intermediate[condition][target.columns].apply(tuple, 1)
+ )
+ matched.sf_type = ColumnType(BooleanType(), True)
+ matched_rows = target[matched]
+ intermediate = intermediate[condition]
+ else:
+ matched_rows = target
+
+ # Calculate multi_join
+ matched_count = intermediate[target.columns].value_counts(dropna=False)[
+ matched_rows.apply(tuple, 1)
+ ]
+ multi_joins = matched_count.where(lambda x: x > 1).count()
+
+ # Select rows that match the condition to be updated
+ rows_to_update = intermediate.drop_duplicates(
+ subset=matched_rows.columns, keep="first"
+ ).reset_index( # ERROR_ON_NONDETERMINISTIC_UPDATE is by default False, pick one row to update
+ drop=True
)
- matched.sf_type = ColumnType(BooleanType(), True)
- matched_rows = target[matched]
- intermediate = intermediate[condition]
- else:
- matched_rows = target
+ rows_to_update.sf_types = intermediate.sf_types
+
+ # Update rows in place
+ for attr, new_expr in source_plan.assignments.items():
+ column_name = analyzer.analyze(attr, expr_to_alias)
+ target_index = target.loc[rows_to_update[ROW_ID]].index
+ new_val = calculate_expression(
+ new_expr, rows_to_update, analyzer, expr_to_alias
+ )
+ new_val.index = target_index
+ target.loc[rows_to_update[ROW_ID], column_name] = new_val
- # Calculate multi_join
- matched_count = intermediate[target.columns].value_counts(dropna=False)[
- matched_rows.apply(tuple, 1)
- ]
- multi_joins = matched_count.where(lambda x: x > 1).count()
+ # Delete row_id
+ target = target.drop(ROW_ID, axis=1)
- # Select rows that match the condition to be updated
- rows_to_update = intermediate.drop_duplicates(
- subset=matched_rows.columns, keep="first"
- ).reset_index( # ERROR_ON_NONDETERMINISTIC_UPDATE is by default False, pick one row to update
- drop=True
- )
- rows_to_update.sf_types = intermediate.sf_types
-
- # Update rows in place
- for attr, new_expr in source_plan.assignments.items():
- column_name = analyzer.analyze(attr, expr_to_alias)
- target_index = target.loc[rows_to_update[ROW_ID]].index
- new_val = calculate_expression(
- new_expr, rows_to_update, analyzer, expr_to_alias
+ # Write result back to table
+ entity_registry.write_table(
+ source_plan.table_name, target, SaveMode.OVERWRITE
)
- new_val.index = target_index
- target.loc[rows_to_update[ROW_ID], column_name] = new_val
-
- # Delete row_id
- target = target.drop(ROW_ID, axis=1)
-
- # Write result back to table
- entity_registry.write_table(source_plan.table_name, target, SaveMode.OVERWRITE)
return [Row(len(rows_to_update), multi_joins)]
elif isinstance(source_plan, TableDelete):
- target = entity_registry.read_table(source_plan.table_name)
+ # since we are modifying the table, we need to ensure that no other thread
+ # reads the table until it is updated
+ with connection_lock:
+ target = entity_registry.read_table(source_plan.table_name)
+
+ if source_plan.source_data:
+ # Calculate cartesian product
+ source = execute_mock_plan(source_plan.source_data, expr_to_alias)
+ cartesian_product = target.merge(source, on=None, how="cross")
+ cartesian_product.sf_types.update(target.sf_types)
+ cartesian_product.sf_types.update(source.sf_types)
+ intermediate = cartesian_product
+ else:
+ intermediate = target
+
+ # Select rows to keep based on condition
+ if source_plan.condition:
+ condition = calculate_expression(
+ source_plan.condition, intermediate, analyzer, expr_to_alias
+ ).fillna(value=False)
+ intermediate = intermediate[condition]
+ matched = target.apply(tuple, 1).isin(
+ intermediate[target.columns].apply(tuple, 1)
+ )
+ matched.sf_type = ColumnType(BooleanType(), True)
+ rows_to_keep = target[~matched]
+ else:
+ rows_to_keep = target.head(0)
- if source_plan.source_data:
+ # Write rows to keep to table registry
+ entity_registry.write_table(
+ source_plan.table_name, rows_to_keep, SaveMode.OVERWRITE
+ )
+ return [Row(len(target) - len(rows_to_keep))]
+ elif isinstance(source_plan, TableMerge):
+ # since we are modifying the table, we need to ensure that no other thread
+ # reads the table until it is updated
+ with connection_lock:
+ target = entity_registry.read_table(source_plan.table_name)
+ ROW_ID = "row_id_" + generate_random_alphanumeric()
+ SOURCE_ROW_ID = "source_row_id_" + generate_random_alphanumeric()
# Calculate cartesian product
- source = execute_mock_plan(source_plan.source_data, expr_to_alias)
+ source = execute_mock_plan(source_plan.source, expr_to_alias)
+
+ # Insert row_id and source row_id
+ target.insert(0, ROW_ID, range(len(target)))
+ source.insert(0, SOURCE_ROW_ID, range(len(source)))
+
cartesian_product = target.merge(source, on=None, how="cross")
cartesian_product.sf_types.update(target.sf_types)
cartesian_product.sf_types.update(source.sf_types)
- intermediate = cartesian_product
- else:
- intermediate = target
-
- # Select rows to keep based on condition
- if source_plan.condition:
- condition = calculate_expression(
- source_plan.condition, intermediate, analyzer, expr_to_alias
- ).fillna(value=False)
- intermediate = intermediate[condition]
- matched = target.apply(tuple, 1).isin(
- intermediate[target.columns].apply(tuple, 1)
+ join_condition = calculate_expression(
+ source_plan.join_expr, cartesian_product, analyzer, expr_to_alias
)
- matched.sf_type = ColumnType(BooleanType(), True)
- rows_to_keep = target[~matched]
- else:
- rows_to_keep = target.head(0)
+ join_result = cartesian_product[join_condition]
+ join_result.sf_types = cartesian_product.sf_types
+
+ # TODO [GA]: # ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if
+ # (1) A target row is selected to be updated with multiple values OR
+ # (2) A target row is selected to be both updated and deleted
+
+ inserted_rows = []
+ insert_clause_specified = (
+ update_clause_specified
+ ) = delete_clause_specified = False
+ inserted_row_idx = set() # source_row_id
+ deleted_row_idx = set()
+ updated_row_idx = set()
+ for clause in source_plan.clauses:
+ if isinstance(clause, UpdateMergeExpression):
+ update_clause_specified = True
+ # Select rows to update
+ if clause.condition:
+ condition = calculate_expression(
+ clause.condition, join_result, analyzer, expr_to_alias
+ ).fillna(value=False)
+ rows_to_update = join_result[condition]
+ else:
+ rows_to_update = join_result
- # Write rows to keep to table registry
- entity_registry.write_table(
- source_plan.table_name, rows_to_keep, SaveMode.OVERWRITE
- )
- return [Row(len(target) - len(rows_to_keep))]
- elif isinstance(source_plan, TableMerge):
- target = entity_registry.read_table(source_plan.table_name)
- ROW_ID = "row_id_" + generate_random_alphanumeric()
- SOURCE_ROW_ID = "source_row_id_" + generate_random_alphanumeric()
- # Calculate cartesian product
- source = execute_mock_plan(source_plan.source, expr_to_alias)
-
- # Insert row_id and source row_id
- target.insert(0, ROW_ID, range(len(target)))
- source.insert(0, SOURCE_ROW_ID, range(len(source)))
-
- cartesian_product = target.merge(source, on=None, how="cross")
- cartesian_product.sf_types.update(target.sf_types)
- cartesian_product.sf_types.update(source.sf_types)
- join_condition = calculate_expression(
- source_plan.join_expr, cartesian_product, analyzer, expr_to_alias
- )
- join_result = cartesian_product[join_condition]
- join_result.sf_types = cartesian_product.sf_types
-
- # TODO [GA]: # ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if
- # (1) A target row is selected to be updated with multiple values OR
- # (2) A target row is selected to be both updated and deleted
-
- inserted_rows = []
- insert_clause_specified = (
- update_clause_specified
- ) = delete_clause_specified = False
- inserted_row_idx = set() # source_row_id
- deleted_row_idx = set()
- updated_row_idx = set()
- for clause in source_plan.clauses:
- if isinstance(clause, UpdateMergeExpression):
- update_clause_specified = True
- # Select rows to update
- if clause.condition:
- condition = calculate_expression(
- clause.condition, join_result, analyzer, expr_to_alias
- ).fillna(value=False)
- rows_to_update = join_result[condition]
- else:
- rows_to_update = join_result
+ rows_to_update = rows_to_update[
+ ~rows_to_update[ROW_ID]
+ .isin(updated_row_idx.union(deleted_row_idx))
+ .values
+ ]
- rows_to_update = rows_to_update[
- ~rows_to_update[ROW_ID]
- .isin(updated_row_idx.union(deleted_row_idx))
- .values
- ]
+ # Update rows in place
+ for attr, new_expr in clause.assignments.items():
+ column_name = analyzer.analyze(attr, expr_to_alias)
+ target_index = target.loc[rows_to_update[ROW_ID]].index
+ new_val = calculate_expression(
+ new_expr, rows_to_update, analyzer, expr_to_alias
+ )
+ new_val.index = target_index
+ target.loc[rows_to_update[ROW_ID], column_name] = new_val
+
+ # Update updated row id set
+ for _, row in rows_to_update.iterrows():
+ updated_row_idx.add(row[ROW_ID])
+
+ elif isinstance(clause, DeleteMergeExpression):
+ delete_clause_specified = True
+ # Select rows to delete
+ if clause.condition:
+ condition = calculate_expression(
+ clause.condition, join_result, analyzer, expr_to_alias
+ ).fillna(value=False)
+ intermediate = join_result[condition]
+ else:
+ intermediate = join_result
- # Update rows in place
- for attr, new_expr in clause.assignments.items():
- column_name = analyzer.analyze(attr, expr_to_alias)
- target_index = target.loc[rows_to_update[ROW_ID]].index
- new_val = calculate_expression(
- new_expr, rows_to_update, analyzer, expr_to_alias
+ matched = target.apply(tuple, 1).isin(
+ intermediate[target.columns].apply(tuple, 1)
)
- new_val.index = target_index
- target.loc[rows_to_update[ROW_ID], column_name] = new_val
-
- # Update updated row id set
- for _, row in rows_to_update.iterrows():
- updated_row_idx.add(row[ROW_ID])
-
- elif isinstance(clause, DeleteMergeExpression):
- delete_clause_specified = True
- # Select rows to delete
- if clause.condition:
- condition = calculate_expression(
- clause.condition, join_result, analyzer, expr_to_alias
- ).fillna(value=False)
- intermediate = join_result[condition]
- else:
- intermediate = join_result
+ matched.sf_type = ColumnType(BooleanType(), True)
- matched = target.apply(tuple, 1).isin(
- intermediate[target.columns].apply(tuple, 1)
- )
- matched.sf_type = ColumnType(BooleanType(), True)
+ # Update deleted row id set
+ for _, row in target[matched].iterrows():
+ deleted_row_idx.add(row[ROW_ID])
- # Update deleted row id set
- for _, row in target[matched].iterrows():
- deleted_row_idx.add(row[ROW_ID])
+ # Delete rows in place
+ target = target[~matched]
- # Delete rows in place
- target = target[~matched]
+ elif isinstance(clause, InsertMergeExpression):
+ insert_clause_specified = True
+ # calculate unmatched rows in the source
+ matched = source.apply(tuple, 1).isin(
+ join_result[source.columns].apply(tuple, 1)
+ )
+ matched.sf_type = ColumnType(BooleanType(), True)
+ unmatched_rows_in_source = source[~matched]
+
+ # select unmatched rows that qualify the condition
+ if clause.condition:
+ condition = calculate_expression(
+ clause.condition,
+ unmatched_rows_in_source,
+ analyzer,
+ expr_to_alias,
+ ).fillna(value=False)
+ unmatched_rows_in_source = unmatched_rows_in_source[condition]
+
+ # filter out the unmatched rows that have been inserted in previous clauses
+ unmatched_rows_in_source = unmatched_rows_in_source[
+ ~unmatched_rows_in_source[SOURCE_ROW_ID]
+ .isin(inserted_row_idx)
+ .values
+ ]
- elif isinstance(clause, InsertMergeExpression):
- insert_clause_specified = True
- # calculate unmatched rows in the source
- matched = source.apply(tuple, 1).isin(
- join_result[source.columns].apply(tuple, 1)
- )
- matched.sf_type = ColumnType(BooleanType(), True)
- unmatched_rows_in_source = source[~matched]
+ # update inserted row idx set
+ for _, row in unmatched_rows_in_source.iterrows():
+ inserted_row_idx.add(row[SOURCE_ROW_ID])
- # select unmatched rows that qualify the condition
- if clause.condition:
- condition = calculate_expression(
- clause.condition,
- unmatched_rows_in_source,
- analyzer,
- expr_to_alias,
- ).fillna(value=False)
- unmatched_rows_in_source = unmatched_rows_in_source[condition]
-
- # filter out the unmatched rows that have been inserted in previous clauses
- unmatched_rows_in_source = unmatched_rows_in_source[
- ~unmatched_rows_in_source[SOURCE_ROW_ID]
- .isin(inserted_row_idx)
- .values
- ]
+ # Calculate rows to insert
+ rows_to_insert = TableEmulator(
+ [], columns=target.drop(ROW_ID, axis=1).columns, dtype=object
+ )
+ rows_to_insert.sf_types = target.sf_types
+ if clause.keys:
+ # Keep track of specified columns
+ inserted_columns = set()
+ for k, v in zip(clause.keys, clause.values):
+ column_name = analyzer.analyze(k, expr_to_alias)
+ if column_name not in rows_to_insert.columns:
+ raise SnowparkLocalTestingException(
+ f"invalid identifier '{column_name}'"
+ )
+ inserted_columns.add(column_name)
+ new_val = calculate_expression(
+ v, unmatched_rows_in_source, analyzer, expr_to_alias
+ )
+ # pandas could do implicit type conversion, e.g. from datetime to timestamp
+ # reconstructing ColumnEmulator helps preserve the original date type
+ rows_to_insert[column_name] = ColumnEmulator(
+ new_val.values,
+ dtype=object,
+ sf_type=rows_to_insert[column_name].sf_type,
+ )
- # update inserted row idx set
- for _, row in unmatched_rows_in_source.iterrows():
- inserted_row_idx.add(row[SOURCE_ROW_ID])
+ # For unspecified columns, use None as default value
+ for unspecified_col in set(rows_to_insert.columns).difference(
+ inserted_columns
+ ):
+ rows_to_insert[unspecified_col].replace(
+ np.nan, None, inplace=True
+ )
- # Calculate rows to insert
- rows_to_insert = TableEmulator(
- [], columns=target.drop(ROW_ID, axis=1).columns, dtype=object
- )
- rows_to_insert.sf_types = target.sf_types
- if clause.keys:
- # Keep track of specified columns
- inserted_columns = set()
- for k, v in zip(clause.keys, clause.values):
- column_name = analyzer.analyze(k, expr_to_alias)
- if column_name not in rows_to_insert.columns:
+ else:
+ if len(clause.values) != len(rows_to_insert.columns):
raise SnowparkLocalTestingException(
- f"invalid identifier '{column_name}'"
+ f"Insert value list does not match column list expecting {len(rows_to_insert.columns)} but got {len(clause.values)}"
)
- inserted_columns.add(column_name)
- new_val = calculate_expression(
- v, unmatched_rows_in_source, analyzer, expr_to_alias
- )
- # pandas could do implicit type conversion, e.g. from datetime to timestamp
- # reconstructing ColumnEmulator helps preserve the original date type
- rows_to_insert[column_name] = ColumnEmulator(
- new_val.values,
- dtype=object,
- sf_type=rows_to_insert[column_name].sf_type,
- )
-
- # For unspecified columns, use None as default value
- for unspecified_col in set(rows_to_insert.columns).difference(
- inserted_columns
- ):
- rows_to_insert[unspecified_col].replace(
- np.nan, None, inplace=True
- )
-
- else:
- if len(clause.values) != len(rows_to_insert.columns):
- raise SnowparkLocalTestingException(
- f"Insert value list does not match column list expecting {len(rows_to_insert.columns)} but got {len(clause.values)}"
- )
- for col, v in zip(rows_to_insert.columns, clause.values):
- new_val = calculate_expression(
- v, unmatched_rows_in_source, analyzer, expr_to_alias
- )
- rows_to_insert[col] = new_val
+ for col, v in zip(rows_to_insert.columns, clause.values):
+ new_val = calculate_expression(
+ v, unmatched_rows_in_source, analyzer, expr_to_alias
+ )
+ rows_to_insert[col] = new_val
- inserted_rows.append(rows_to_insert)
+ inserted_rows.append(rows_to_insert)
- # Remove inserted ROW ID column
- target = target.drop(ROW_ID, axis=1)
+ # Remove inserted ROW ID column
+ target = target.drop(ROW_ID, axis=1)
- # Process inserted rows
- if inserted_rows:
- res = pd.concat([target] + inserted_rows)
- res.sf_types = target.sf_types
- else:
- res = target
+ # Process inserted rows
+ if inserted_rows:
+ res = pd.concat([target] + inserted_rows)
+ res.sf_types = target.sf_types
+ else:
+ res = target
- # Write the result back to table
- entity_registry.write_table(source_plan.table_name, res, SaveMode.OVERWRITE)
+ # Write the result back to table
+ entity_registry.write_table(source_plan.table_name, res, SaveMode.OVERWRITE)
# Generate metadata result
res = []
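The `TableUpdate`, `TableDelete`, and `TableMerge` branches above now hold the connection lock across the whole read-modify-write of the target table. A minimal sketch of that pattern; `registry` and `lock` are stand-ins for `entity_registry` and `connection_lock`, and the list-of-rows table is purely illustrative:

```
import threading
from typing import Dict, List


def delete_rows(registry: Dict[str, List[int]], lock, name: str, value: int) -> int:
    # Hold the lock from the read to the write-back so no other thread can
    # observe or overwrite the table while it is being modified.
    with lock:
        rows = registry[name]
        kept = [row for row in rows if row != value]
        registry[name] = kept
        return len(rows) - len(kept)


tables = {"t1": [1, 2, 2, 3]}
table_lock = threading.RLock()
assert delete_rows(tables, table_lock, "t1", 2) == 2  # two rows removed
```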
diff --git a/src/snowflake/snowpark/mock/_stage_registry.py b/src/snowflake/snowpark/mock/_stage_registry.py
index 7ed55d1cdc6..d4100606821 100644
--- a/src/snowflake/snowpark/mock/_stage_registry.py
+++ b/src/snowflake/snowpark/mock/_stage_registry.py
@@ -647,30 +647,34 @@ def __init__(self, conn: "MockServerConnection") -> None:
self._root_dir = tempfile.TemporaryDirectory()
self._stage_registry = {}
self._conn = conn
+ self._lock = conn.get_lock()
def create_or_replace_stage(self, stage_name):
- self._stage_registry[stage_name] = StageEntity(
- self._root_dir.name, stage_name, self._conn
- )
+ with self._lock:
+ self._stage_registry[stage_name] = StageEntity(
+ self._root_dir.name, stage_name, self._conn
+ )
def __getitem__(self, stage_name: str):
# the assumption here is that stage always exists
- if stage_name not in self._stage_registry:
- self.create_or_replace_stage(stage_name)
- return self._stage_registry[stage_name]
+ with self._lock:
+ if stage_name not in self._stage_registry:
+ self.create_or_replace_stage(stage_name)
+ return self._stage_registry[stage_name]
def put(
self, local_file_name: str, stage_location: str, overwrite: bool = False
) -> TableEmulator:
stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location)
# the assumption here is that stage always exists
- if stage_name not in self._stage_registry:
- self.create_or_replace_stage(stage_name)
- return self._stage_registry[stage_name].put_file(
- local_file_name=local_file_name,
- stage_prefix=stage_prefix,
- overwrite=overwrite,
- )
+ with self._lock:
+ if stage_name not in self._stage_registry:
+ self.create_or_replace_stage(stage_name)
+ return self._stage_registry[stage_name].put_file(
+ local_file_name=local_file_name,
+ stage_prefix=stage_prefix,
+ overwrite=overwrite,
+ )
def upload_stream(
self,
@@ -681,14 +685,15 @@ def upload_stream(
) -> Dict:
stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location)
# the assumption here is that stage always exists
- if stage_name not in self._stage_registry:
- self.create_or_replace_stage(stage_name)
- return self._stage_registry[stage_name].upload_stream(
- input_stream=input_stream,
- stage_prefix=stage_prefix,
- file_name=file_name,
- overwrite=overwrite,
- )
+ with self._lock:
+ if stage_name not in self._stage_registry:
+ self.create_or_replace_stage(stage_name)
+ return self._stage_registry[stage_name].upload_stream(
+ input_stream=input_stream,
+ stage_prefix=stage_prefix,
+ file_name=file_name,
+ overwrite=overwrite,
+ )
def get(
self,
@@ -701,14 +706,15 @@ def get(
f"Invalid stage {stage_location}, stage name should start with character '@'"
)
stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location)
- if stage_name not in self._stage_registry:
- self.create_or_replace_stage(stage_name)
-
- return self._stage_registry[stage_name].get_file(
- stage_location=stage_prefix,
- target_directory=target_directory,
- options=options,
- )
+ with self._lock:
+ if stage_name not in self._stage_registry:
+ self.create_or_replace_stage(stage_name)
+
+ return self._stage_registry[stage_name].get_file(
+ stage_location=stage_prefix,
+ target_directory=target_directory,
+ options=options,
+ )
def read_file(
self,
@@ -723,13 +729,14 @@ def read_file(
f"Invalid stage {stage_location}, stage name should start with character '@'"
)
stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location)
- if stage_name not in self._stage_registry:
- self.create_or_replace_stage(stage_name)
-
- return self._stage_registry[stage_name].read_file(
- stage_location=stage_prefix,
- format=format,
- schema=schema,
- analyzer=analyzer,
- options=options,
- )
+ with self._lock:
+ if stage_name not in self._stage_registry:
+ self.create_or_replace_stage(stage_name)
+
+ return self._stage_registry[stage_name].read_file(
+ stage_location=stage_prefix,
+ format=format,
+ schema=schema,
+ analyzer=analyzer,
+ options=options,
+ )
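The stage registry above now reuses the connection's lock via `conn.get_lock()` instead of creating its own, and the create-stage-if-missing check in `__getitem__`, `put`, `get`, and `read_file` runs while that lock is held. A rough sketch under assumed names (`StageStore` is not the Snowpark class):

```
import threading
from typing import Dict


class StageStore:
    def __init__(self, shared_lock) -> None:
        # Borrow the owning connection's lock so stage operations serialize
        # with everything else that uses the same lock.
        self._lock = shared_lock
        self._stages: Dict[str, dict] = {}

    def __getitem__(self, name: str) -> dict:
        with self._lock:
            # Create-on-first-use and the subsequent read are atomic.
            if name not in self._stages:
                self._stages[name] = {}
            return self._stages[name]


store = StageStore(threading.RLock())
store["@my_stage"]["greeting.txt"] = b"hello"
```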
diff --git a/src/snowflake/snowpark/mock/_stored_procedure.py b/src/snowflake/snowpark/mock/_stored_procedure.py
index d93500da2e8..14abec358c2 100644
--- a/src/snowflake/snowpark/mock/_stored_procedure.py
+++ b/src/snowflake/snowpark/mock/_stored_procedure.py
@@ -154,9 +154,11 @@ def __init__(self, *args, **kwargs) -> None:
) # maps name to either the callable or a pair of str (module_name, callable_name)
self._sproc_level_imports = dict() # maps name to a set of file paths
self._session_level_imports = set()
+ self._lock = self._session._conn.get_lock()
def _clear_session_imports(self):
- self._session_level_imports.clear()
+ with self._lock:
+ self._session_level_imports.clear()
def _import_file(
self,
@@ -172,16 +174,17 @@ def _import_file(
imports specified.
"""
- absolute_module_path, module_name = extract_import_dir_and_module_name(
- file_path, self._session._conn.stage_registry, import_path
- )
+ with self._lock:
+ absolute_module_path, module_name = extract_import_dir_and_module_name(
+ file_path, self._session._conn.stage_registry, import_path
+ )
- if sproc_name:
- self._sproc_level_imports[sproc_name].add(absolute_module_path)
- else:
- self._session_level_imports.add(absolute_module_path)
+ if sproc_name:
+ self._sproc_level_imports[sproc_name].add(absolute_module_path)
+ else:
+ self._session_level_imports.add(absolute_module_path)
- return module_name
+ return module_name
def _do_register_sp(
self,
@@ -224,90 +227,96 @@ def _do_register_sp(
error_message="Registering anonymous sproc is not currently supported.",
raise_error=NotImplementedError,
)
- (
- sproc_name,
- is_pandas_udf,
- is_dataframe_input,
- return_type,
- input_types,
- opt_arg_defaults,
- ) = process_registration_inputs(
- self._session,
- TempObjectType.PROCEDURE,
- func,
- return_type,
- input_types,
- sp_name,
- anonymous,
- )
- current_schema = self._session.get_current_schema()
- current_database = self._session.get_current_database()
- sproc_name = get_fully_qualified_name(
- sproc_name, current_schema, current_database
- )
-
- check_python_runtime_version(self._session._runtime_version_from_requirement)
-
- if replace and if_not_exists:
- raise ValueError("options replace and if_not_exists are incompatible")
+ with self._lock:
+ (
+ sproc_name,
+ is_pandas_udf,
+ is_dataframe_input,
+ return_type,
+ input_types,
+ opt_arg_defaults,
+ ) = process_registration_inputs(
+ self._session,
+ TempObjectType.PROCEDURE,
+ func,
+ return_type,
+ input_types,
+ sp_name,
+ anonymous,
+ )
- if sproc_name in self._registry and if_not_exists:
- return self._registry[sproc_name]
+ current_schema = self._session.get_current_schema()
+ current_database = self._session.get_current_database()
+ sproc_name = get_fully_qualified_name(
+ sproc_name, current_schema, current_database
+ )
- if sproc_name in self._registry and not replace:
- raise SnowparkLocalTestingException(
- f"002002 (42710): SQL compilation error: \nObject '{sproc_name}' already exists.",
- error_code="1304",
+ check_python_runtime_version(
+ self._session._runtime_version_from_requirement
)
- if is_pandas_udf:
- raise TypeError("pandas stored procedure is not supported")
+ if replace and if_not_exists:
+ raise ValueError("options replace and if_not_exists are incompatible")
- if packages:
- pass # NO-OP
+ if sproc_name in self._registry and if_not_exists:
+ return self._registry[sproc_name]
- if imports is not None or type(func) is tuple:
- self._sproc_level_imports[sproc_name] = set()
+ if sproc_name in self._registry and not replace:
+ raise SnowparkLocalTestingException(
+ f"002002 (42710): SQL compilation error: \nObject '{sproc_name}' already exists.",
+ error_code="1304",
+ )
- if imports is not None:
- for _import in imports:
- if isinstance(_import, str):
- self._import_file(_import, sproc_name=sproc_name)
- elif isinstance(_import, tuple) and all(
- isinstance(item, str) for item in _import
- ):
- local_path, import_path = _import
- self._import_file(local_path, import_path, sproc_name=sproc_name)
- else:
- raise TypeError(
- "stored-proc-level import can only be a file path (str) or a tuple of the file path (str) and the import path (str)"
- )
+ if is_pandas_udf:
+ raise TypeError("pandas stored procedure is not supported")
- if type(func) is tuple: # register from file
- if sproc_name not in self._sproc_level_imports:
- self._sproc_level_imports[sproc_name] = set()
- module_name = self._import_file(func[0], sproc_name=sproc_name)
- func = (module_name, func[1])
+ if packages:
+ pass # NO-OP
- if sproc_name in self._sproc_level_imports:
- sproc_imports = self._sproc_level_imports[sproc_name]
- else:
- sproc_imports = copy(self._session_level_imports)
+ if imports is not None or type(func) is tuple:
+ self._sproc_level_imports[sproc_name] = set()
- sproc = MockStoredProcedure(
- func,
- return_type,
- input_types,
- sproc_name,
- sproc_imports,
- execute_as=execute_as,
- strict=strict,
- )
+ if imports is not None:
+ for _import in imports:
+ if isinstance(_import, str):
+ self._import_file(_import, sproc_name=sproc_name)
+ elif isinstance(_import, tuple) and all(
+ isinstance(item, str) for item in _import
+ ):
+ local_path, import_path = _import
+ self._import_file(
+ local_path, import_path, sproc_name=sproc_name
+ )
+ else:
+ raise TypeError(
+ "stored-proc-level import can only be a file path (str) or a tuple of the file path (str) and the import path (str)"
+ )
+
+ if type(func) is tuple: # register from file
+ if sproc_name not in self._sproc_level_imports:
+ self._sproc_level_imports[sproc_name] = set()
+ module_name = self._import_file(func[0], sproc_name=sproc_name)
+ func = (module_name, func[1])
+
+ if sproc_name in self._sproc_level_imports:
+ sproc_imports = self._sproc_level_imports[sproc_name]
+ else:
+ sproc_imports = copy(self._session_level_imports)
+
+ sproc = MockStoredProcedure(
+ func,
+ return_type,
+ input_types,
+ sproc_name,
+ sproc_imports,
+ execute_as=execute_as,
+ strict=strict,
+ )
- self._registry[sproc_name] = sproc
+ self._registry[sproc_name] = sproc
- return sproc
+ return sproc
def call(
self,
@@ -316,17 +325,18 @@ def call(
session: Optional["snowflake.snowpark.session.Session"] = None,
statement_params: Optional[Dict[str, str]] = None,
):
- current_schema = self._session.get_current_schema()
- current_database = self._session.get_current_database()
- sproc_name = get_fully_qualified_name(
- sproc_name, current_schema, current_database
- )
-
- if sproc_name not in self._registry:
- raise SnowparkLocalTestingException(
- f"Unknown function {sproc_name}. Stored procedure by that name does not exist."
+ with self._lock:
+ current_schema = self._session.get_current_schema()
+ current_database = self._session.get_current_database()
+ sproc_name = get_fully_qualified_name(
+ sproc_name, current_schema, current_database
)
- return self._registry[sproc_name](
- *args, session=session, statement_params=statement_params
- )
+ if sproc_name not in self._registry:
+ raise SnowparkLocalTestingException(
+ f"Unknown function {sproc_name}. Stored procedure by that name does not exist."
+ )
+
+ sproc = self._registry[sproc_name]
+
+ return sproc(*args, session=session, statement_params=statement_params)
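In the stored procedure registry above, name resolution, the `replace`/`if_not_exists` checks, and the final registration all move under a single `with self._lock:` block, so a concurrent register or call cannot slip in between the lookup and the insert. A simplified sketch of that check-then-register step; the function and the `ValueError` are placeholders, not the Snowpark error types:

```
import threading
from typing import Callable, Dict


def register(registry: Dict[str, Callable], lock, name: str, func: Callable,
             replace: bool = False, if_not_exists: bool = False) -> Callable:
    with lock:
        if replace and if_not_exists:
            raise ValueError("options replace and if_not_exists are incompatible")
        if name in registry and if_not_exists:
            return registry[name]  # keep the existing registration
        if name in registry and not replace:
            raise ValueError(f"Object '{name}' already exists.")
        registry[name] = func
        return func


procs: Dict[str, Callable] = {}
registry_lock = threading.RLock()
register(procs, registry_lock, "MY_PROC", lambda: 42)
```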
diff --git a/src/snowflake/snowpark/mock/_telemetry.py b/src/snowflake/snowpark/mock/_telemetry.py
index 857291b47fd..6e4273aa7ff 100644
--- a/src/snowflake/snowpark/mock/_telemetry.py
+++ b/src/snowflake/snowpark/mock/_telemetry.py
@@ -5,6 +5,7 @@
import json
import logging
import os
+import threading
import uuid
from datetime import datetime
from enum import Enum
@@ -92,6 +93,7 @@ def __init__(self) -> None:
)
self._deployment_url = self.PROD
self._enable = True
+ self._lock = threading.RLock()
def _upload_payload(self, payload) -> None:
if not REQUESTS_AVAILABLE:
@@ -136,12 +138,25 @@ def add(self, event) -> None:
if not self.enabled:
return
- self.queue.put(event)
- if self.queue.qsize() > self.batch_size:
- payload = self.export_queue_to_string()
- if payload is None:
- return
- self._upload_payload(payload)
+ with self._lock:
+ self.queue.put(event)
+ if self.queue.qsize() > self.batch_size:
+ payload = self.export_queue_to_string()
+ if payload is None:
+ return
+ self._upload_payload(payload)
+
+ def flush(self) -> None:
+        """Flushes all telemetry events in the queue and submits them to the back-end."""
+ if not self.enabled:
+ return
+
+ with self._lock:
+ if not self.queue.empty():
+ payload = self.export_queue_to_string()
+ if payload is None:
+ return
+ self._upload_payload(payload)
@property
def enabled(self) -> bool:
@@ -158,8 +173,9 @@ def disable(self) -> None:
def export_queue_to_string(self):
logs = list()
- while not self.queue.empty():
- logs.append(self.queue.get())
+ with self._lock:
+ while not self.queue.empty():
+ logs.append(self.queue.get())
# We may get an exception trying to serialize a python object to JSON
try:
payload = json.dumps(logs)
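The new `flush` method above drains the queue and uploads a single payload while holding the lock, and `add`/`export_queue_to_string` take the same lock so events cannot be lost between the size check and the drain. A minimal sketch; the buffer class and the `print` stand-in for the upload step are illustrative only:

```
import json
import queue
import threading


class TelemetryBuffer:
    def __init__(self, batch_size: int = 100) -> None:
        self._queue: queue.Queue = queue.Queue()
        self._lock = threading.RLock()
        self._batch_size = batch_size

    def add(self, event: dict) -> None:
        with self._lock:
            self._queue.put(event)
            if self._queue.qsize() > self._batch_size:
                self.flush()  # the reentrant lock allows this nested call

    def flush(self) -> None:
        with self._lock:
            events = []
            while not self._queue.empty():
                events.append(self._queue.get())
            if events:
                print(json.dumps(events))  # stand-in for the real upload
```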
diff --git a/src/snowflake/snowpark/mock/_udf.py b/src/snowflake/snowpark/mock/_udf.py
index 7cedf0de660..a7a17d9a030 100644
--- a/src/snowflake/snowpark/mock/_udf.py
+++ b/src/snowflake/snowpark/mock/_udf.py
@@ -38,9 +38,11 @@ def __init__(self, *args, **kwargs) -> None:
dict()
) # maps udf name to either the callable or a pair of str (module_name, callable_name)
self._session_level_imports = set()
+ self._lock = self._session._conn.get_lock()
def _clear_session_imports(self):
- self._session_level_imports.clear()
+ with self._lock:
+ self._session_level_imports.clear()
def _import_file(
self,
@@ -54,29 +56,32 @@ def _import_file(
When udf_name is not None, the import is added to the UDF associated with the name;
Otherwise, it is a session level import and will be used if no UDF-level imports are specified.
"""
- absolute_module_path, module_name = extract_import_dir_and_module_name(
- file_path, self._session._conn.stage_registry, import_path
- )
- if udf_name:
- self._registry[udf_name].add_import(absolute_module_path)
- else:
- self._session_level_imports.add(absolute_module_path)
+ with self._lock:
+ absolute_module_path, module_name = extract_import_dir_and_module_name(
+ file_path, self._session._conn.stage_registry, import_path
+ )
+ if udf_name:
+ self._registry[udf_name].add_import(absolute_module_path)
+ else:
+ self._session_level_imports.add(absolute_module_path)
- return module_name
+ return module_name
def get_udf(self, udf_name: str) -> MockUserDefinedFunction:
- if udf_name not in self._registry:
- raise SnowparkLocalTestingException(f"udf {udf_name} does not exist.")
- return self._registry[udf_name]
+ with self._lock:
+ if udf_name not in self._registry:
+ raise SnowparkLocalTestingException(f"udf {udf_name} does not exist.")
+ return self._registry[udf_name]
def get_udf_imports(self, udf_name: str) -> Set[str]:
- udf = self._registry.get(udf_name)
- if not udf:
- return set()
- elif udf.use_session_imports:
- return self._session_level_imports
- else:
- return udf._imports
+ with self._lock:
+ udf = self._registry.get(udf_name)
+ if not udf:
+ return set()
+ elif udf.use_session_imports:
+ return self._session_level_imports
+ else:
+ return udf._imports
def _do_register_udf(
self,
@@ -113,73 +118,81 @@ def _do_register_udf(
raise_error=NotImplementedError,
)
- # get the udf name, return and input types
- (
- udf_name,
- is_pandas_udf,
- is_dataframe_input,
- return_type,
- input_types,
- opt_arg_defaults,
- ) = process_registration_inputs(
- self._session, TempObjectType.FUNCTION, func, return_type, input_types, name
- )
-
- current_schema = self._session.get_current_schema()
- current_database = self._session.get_current_database()
- udf_name = get_fully_qualified_name(udf_name, current_schema, current_database)
-
- # allow registering pandas UDF from udf(),
- # but not allow registering non-pandas UDF from pandas_udf()
- if from_pandas_udf_function and not is_pandas_udf:
- raise ValueError(
- "You cannot create a non-vectorized UDF using pandas_udf(). "
- "Use udf() instead."
+ with self._lock:
+ # get the udf name, return and input types
+ (
+ udf_name,
+ is_pandas_udf,
+ is_dataframe_input,
+ return_type,
+ input_types,
+ opt_arg_defaults,
+ ) = process_registration_inputs(
+ self._session,
+ TempObjectType.FUNCTION,
+ func,
+ return_type,
+ input_types,
+ name,
)
- custom_python_runtime_version_allowed = False
+ current_schema = self._session.get_current_schema()
+ current_database = self._session.get_current_database()
+ udf_name = get_fully_qualified_name(
+ udf_name, current_schema, current_database
+ )
- if not custom_python_runtime_version_allowed:
- check_python_runtime_version(
- self._session._runtime_version_from_requirement
+ # allow registering pandas UDF from udf(),
+ # but not allow registering non-pandas UDF from pandas_udf()
+ if from_pandas_udf_function and not is_pandas_udf:
+ raise ValueError(
+ "You cannot create a non-vectorized UDF using pandas_udf(). "
+ "Use udf() instead."
+ )
+
+ custom_python_runtime_version_allowed = False
+
+ if not custom_python_runtime_version_allowed:
+ check_python_runtime_version(
+ self._session._runtime_version_from_requirement
+ )
+
+ if replace and if_not_exists:
+ raise ValueError("options replace and if_not_exists are incompatible")
+
+ if udf_name in self._registry and if_not_exists:
+ return self._registry[udf_name]
+
+ if udf_name in self._registry and not replace:
+ raise SnowparkSQLException(
+ f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.",
+ error_code="1304",
+ )
+
+ if packages:
+ pass # NO-OP
+
+ # register
+ self._registry[udf_name] = MockUserDefinedFunction(
+ func,
+ return_type,
+ input_types,
+ udf_name,
+ strict=strict,
+ packages=packages,
+ use_session_imports=imports is None,
)
- if replace and if_not_exists:
- raise ValueError("options replace and if_not_exists are incompatible")
+ if type(func) is tuple: # update file registration
+ module_name = self._import_file(func[0], udf_name=udf_name)
+ self._registry[udf_name].func = (module_name, func[1])
- if udf_name in self._registry and if_not_exists:
- return self._registry[udf_name]
+ if imports is not None:
+ for _import in imports:
+ if type(_import) is str:
+ self._import_file(_import, udf_name=udf_name)
+ else:
+ local_path, import_path = _import
+ self._import_file(local_path, import_path, udf_name=udf_name)
- if udf_name in self._registry and not replace:
- raise SnowparkSQLException(
- f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.",
- error_code="1304",
- )
-
- if packages:
- pass # NO-OP
-
- # register
- self._registry[udf_name] = MockUserDefinedFunction(
- func,
- return_type,
- input_types,
- udf_name,
- strict=strict,
- packages=packages,
- use_session_imports=imports is None,
- )
-
- if type(func) is tuple: # update file registration
- module_name = self._import_file(func[0], udf_name=udf_name)
- self._registry[udf_name].func = (module_name, func[1])
-
- if imports is not None:
- for _import in imports:
- if type(_import) is str:
- self._import_file(_import, udf_name=udf_name)
- else:
- local_path, import_path = _import
- self._import_file(local_path, import_path, udf_name=udf_name)
-
- return self._registry[udf_name]
+ return self._registry[udf_name]
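`get_udf_imports` above now resolves, under the lock, either the UDF-level imports or the session-level imports when the UDF opted into session imports. A small sketch of that fallback; `FakeUdf` and its field names are hypothetical:

```
from dataclasses import dataclass, field
from typing import Set


@dataclass
class FakeUdf:
    use_session_imports: bool = True
    imports: Set[str] = field(default_factory=set)


def resolve_imports(udf: FakeUdf, session_imports: Set[str]) -> Set[str]:
    # Fall back to the session-level imports unless the UDF carries its own.
    return session_imports if udf.use_session_imports else udf.imports


assert resolve_imports(FakeUdf(), {"@stage/helper.py"}) == {"@stage/helper.py"}
```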
diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py
index 6960d0eb629..8f9834630b7 100644
--- a/src/snowflake/snowpark/modin/pandas/__init__.py
+++ b/src/snowflake/snowpark/modin/pandas/__init__.py
@@ -88,13 +88,13 @@
# TODO: SNOW-851745 make sure add all Snowpark pandas API general functions
from modin.pandas import plotting # type: ignore[import]
+from modin.pandas.dataframe import DataFrame
from modin.pandas.series import Series
from snowflake.snowpark.modin.pandas.api.extensions import (
register_dataframe_accessor,
register_series_accessor,
)
-from snowflake.snowpark.modin.pandas.dataframe import DataFrame
from snowflake.snowpark.modin.pandas.general import (
bdate_range,
concat,
@@ -185,10 +185,8 @@
modin.pandas.base._ATTRS_NO_LOOKUP.update(_ATTRS_NO_LOOKUP)
-# For any method defined on Series/DF, add telemetry to it if it:
-# 1. Is defined directly on an upstream class
-# 2. The method name does not start with an _, or is in TELEMETRY_PRIVATE_METHODS
-
+# For any method defined on Series/DF, add telemetry to it if the method name does not start with an
+# _, or the method is in TELEMETRY_PRIVATE_METHODS. This includes methods defined as an extension/override.
for attr_name in dir(Series):
# Since Series is defined in upstream Modin, all of its members were either defined upstream
# or overridden by extension.
@@ -197,11 +195,9 @@
try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name))
)
-
-# TODO: SNOW-1063346
-# Since we still use the vendored version of DataFrame and the overrides for the top-level
-# namespace haven't been performed yet, we need to set properties on the vendored version
for attr_name in dir(DataFrame):
+ # Since DataFrame is defined in upstream Modin, all of its members were either defined upstream
+ # or overridden by extension.
if not attr_name.startswith("_") or attr_name in TELEMETRY_PRIVATE_METHODS:
register_dataframe_accessor(attr_name)(
try_add_telemetry_to_attribute(attr_name, getattr(DataFrame, attr_name))
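The loop above re-registers every public `DataFrame` attribute through `register_dataframe_accessor` after wrapping it with telemetry. The sketch below mirrors the same wrap-and-reassign pattern on a plain class, using a call counter in place of telemetry; `Demo` and `add_counter` are illustrative only:

```
import functools


def add_counter(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)
    wrapper.calls = 0
    return wrapper


class Demo:
    def size(self):
        return 42


# Wrap every public callable, mirroring the dir()-driven loop above.
for attr_name in dir(Demo):
    if not attr_name.startswith("_") and callable(getattr(Demo, attr_name)):
        setattr(Demo, attr_name, add_counter(getattr(Demo, attr_name)))

d = Demo()
d.size()
assert Demo.size.calls == 1
```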
diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py
index 6a34f50e42a..47d44835fe4 100644
--- a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py
+++ b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py
@@ -19,9 +19,12 @@
# existing code originally distributed by the Modin project, under the Apache License,
# Version 2.0.
-from modin.pandas.api.extensions import register_series_accessor
+from modin.pandas.api.extensions import (
+ register_dataframe_accessor,
+ register_series_accessor,
+)
-from .extensions import register_dataframe_accessor, register_pd_accessor
+from .extensions import register_pd_accessor
__all__ = [
"register_dataframe_accessor",
diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py
index 45896292e74..05424c92072 100644
--- a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py
+++ b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py
@@ -86,49 +86,6 @@ def decorator(new_attr: Any):
return decorator
-def register_dataframe_accessor(name: str):
- """
- Registers a dataframe attribute with the name provided.
- This is a decorator that assigns a new attribute to DataFrame. It can be used
- with the following syntax:
- ```
- @register_dataframe_accessor("new_method")
- def my_new_dataframe_method(*args, **kwargs):
- # logic goes here
- return
- ```
- The new attribute can then be accessed with the name provided:
- ```
- df.new_method(*my_args, **my_kwargs)
- ```
-
- If you want a property accessor, you must annotate with @property
- after the call to this function:
- ```
- @register_dataframe_accessor("new_prop")
- @property
- def my_new_dataframe_property(*args, **kwargs):
- return _prop
- ```
-
- Parameters
- ----------
- name : str
- The name of the attribute to assign to DataFrame.
- Returns
- -------
- decorator
- Returns the decorator function.
- """
- import snowflake.snowpark.modin.pandas as pd
-
- return _set_attribute_on_obj(
- name,
- pd.dataframe._DATAFRAME_EXTENSIONS_,
- pd.dataframe.DataFrame,
- )
-
-
def register_pd_accessor(name: str):
"""
Registers a pd namespace attribute with the name provided.
diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py
deleted file mode 100644
index 83893e83e9c..00000000000
--- a/src/snowflake/snowpark/modin/pandas/dataframe.py
+++ /dev/null
@@ -1,3511 +0,0 @@
-#
-# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
-#
-
-# Licensed to Modin Development Team under one or more contributor license agreements.
-# See the NOTICE file distributed with this work for additional information regarding
-# copyright ownership. The Modin Development Team licenses this file to you under the
-# Apache License, Version 2.0 (the "License"); you may not use this file except in
-# compliance with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software distributed under
-# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific language
-# governing permissions and limitations under the License.
-
-# Code in this file may constitute partial or total reimplementation, or modification of
-# existing code originally distributed by the Modin project, under the Apache License,
-# Version 2.0.
-
-"""Module houses ``DataFrame`` class, that is distributed version of ``pandas.DataFrame``."""
-
-from __future__ import annotations
-
-import collections
-import datetime
-import functools
-import itertools
-import re
-import sys
-import warnings
-from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
-from logging import getLogger
-from typing import IO, Any, Callable, Literal
-
-import numpy as np
-import pandas
-from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor
-from modin.pandas.base import BasePandasDataset
-
-# from . import _update_engine
-from modin.pandas.iterator import PartitionIterator
-from modin.pandas.series import Series
-from pandas._libs.lib import NoDefault, no_default
-from pandas._typing import (
- AggFuncType,
- AnyArrayLike,
- Axes,
- Axis,
- CompressionOptions,
- FilePath,
- FillnaOptions,
- IgnoreRaise,
- IndexLabel,
- Level,
- PythonFuncType,
- Renamer,
- Scalar,
- StorageOptions,
- Suffixes,
- WriteBuffer,
-)
-from pandas.core.common import apply_if_callable, is_bool_indexer
-from pandas.core.dtypes.common import (
- infer_dtype_from_object,
- is_bool_dtype,
- is_dict_like,
- is_list_like,
- is_numeric_dtype,
-)
-from pandas.core.dtypes.inference import is_hashable, is_integer
-from pandas.core.indexes.frozen import FrozenList
-from pandas.io.formats.printing import pprint_thing
-from pandas.util._validators import validate_bool_kwarg
-
-from snowflake.snowpark.modin import pandas as pd
-from snowflake.snowpark.modin.pandas.groupby import (
- DataFrameGroupBy,
- validate_groupby_args,
-)
-from snowflake.snowpark.modin.pandas.snow_partition_iterator import (
- SnowparkPandasRowPartitionIterator,
-)
-from snowflake.snowpark.modin.pandas.utils import (
- create_empty_native_pandas_frame,
- from_non_pandas,
- from_pandas,
- is_scalar,
- raise_if_native_pandas_objects,
- replace_external_data_keys_with_empty_pandas_series,
- replace_external_data_keys_with_query_compiler,
-)
-from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated
-from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
-from snowflake.snowpark.modin.plugin.utils.error_message import (
- ErrorMessage,
- dataframe_not_implemented,
-)
-from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP
-from snowflake.snowpark.modin.plugin.utils.warning_message import (
- SET_DATAFRAME_ATTRIBUTE_WARNING,
- WarningMessage,
-)
-from snowflake.snowpark.modin.utils import _inherit_docstrings, hashable, to_pandas
-from snowflake.snowpark.udf import UserDefinedFunction
-
-logger = getLogger(__name__)
-
-DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = (
- "Currently do not support Series or list-like keys with range-like values"
-)
-
-DF_SETITEM_SLICE_AS_SCALAR_VALUE = (
- "Currently do not support assigning a slice value as if it's a scalar value"
-)
-
-DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = (
-    "{} will result in eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark "
- "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which "
- "can work on the entire DataFrame in one shot."
-)
-
-# Dictionary of extensions assigned to this class
-_DATAFRAME_EXTENSIONS_ = {}
-
-
-@_inherit_docstrings(
- pandas.DataFrame,
- excluded=[
- pandas.DataFrame.flags,
- pandas.DataFrame.cov,
- pandas.DataFrame.merge,
- pandas.DataFrame.reindex,
- pandas.DataFrame.to_parquet,
- pandas.DataFrame.fillna,
- ],
- apilink="pandas.DataFrame",
-)
-class DataFrame(BasePandasDataset):
- _pandas_class = pandas.DataFrame
-
- def __init__(
- self,
- data=None,
- index=None,
- columns=None,
- dtype=None,
- copy=None,
- query_compiler=None,
- ) -> None:
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- # Siblings are other dataframes that share the same query compiler. We
- # use this list to update inplace when there is a shallow copy.
- from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native
-
- self._siblings = []
-
- # Engine.subscribe(_update_engine)
- if isinstance(data, (DataFrame, Series)):
- self._query_compiler = data._query_compiler.copy()
- if index is not None and any(i not in data.index for i in index):
- ErrorMessage.not_implemented(
-                    "Passing non-existent columns or index values to constructor not"
- + " yet implemented."
- ) # pragma: no cover
- if isinstance(data, Series):
- # We set the column name if it is not in the provided Series
- if data.name is None:
- self.columns = [0] if columns is None else columns
- # If the columns provided are not in the named Series, pandas clears
- # the DataFrame and sets columns to the columns provided.
- elif columns is not None and data.name not in columns:
- self._query_compiler = from_pandas(
- self.__constructor__(columns=columns)
- )._query_compiler
- if index is not None:
- self._query_compiler = data.loc[index]._query_compiler
- elif columns is None and index is None:
- data._add_sibling(self)
- else:
- if columns is not None and any(i not in data.columns for i in columns):
- ErrorMessage.not_implemented(
-                        "Passing non-existent columns or index values to constructor not"
- + " yet implemented."
- ) # pragma: no cover
- if index is None:
- index = slice(None)
- if columns is None:
- columns = slice(None)
- self._query_compiler = data.loc[index, columns]._query_compiler
-
- # Check type of data and use appropriate constructor
- elif query_compiler is None:
- distributed_frame = from_non_pandas(data, index, columns, dtype)
- if distributed_frame is not None:
- self._query_compiler = distributed_frame._query_compiler
- return
-
- if isinstance(data, pandas.Index):
- pass
- elif is_list_like(data) and not is_dict_like(data):
- old_dtype = getattr(data, "dtype", None)
- values = [
- obj._to_pandas() if isinstance(obj, Series) else obj for obj in data
- ]
- if isinstance(data, np.ndarray):
- data = np.array(values, dtype=old_dtype)
- else:
- try:
- data = type(data)(values, dtype=old_dtype)
- except TypeError:
- data = values
- elif is_dict_like(data) and not isinstance(
- data, (pandas.Series, Series, pandas.DataFrame, DataFrame)
- ):
- if columns is not None:
- data = {key: value for key, value in data.items() if key in columns}
-
- if len(data) and all(isinstance(v, Series) for v in data.values()):
- from .general import concat
-
- new_qc = concat(
- data.values(), axis=1, keys=data.keys()
- )._query_compiler
-
- if dtype is not None:
- new_qc = new_qc.astype({col: dtype for col in new_qc.columns})
- if index is not None:
- new_qc = new_qc.reindex(
- axis=0, labels=try_convert_index_to_native(index)
- )
- if columns is not None:
- new_qc = new_qc.reindex(
- axis=1, labels=try_convert_index_to_native(columns)
- )
-
- self._query_compiler = new_qc
- return
-
- data = {
- k: v._to_pandas() if isinstance(v, Series) else v
- for k, v in data.items()
- }
- pandas_df = pandas.DataFrame(
- data=try_convert_index_to_native(data),
- index=try_convert_index_to_native(index),
- columns=try_convert_index_to_native(columns),
- dtype=dtype,
- copy=copy,
- )
- self._query_compiler = from_pandas(pandas_df)._query_compiler
- else:
- self._query_compiler = query_compiler
-
- def __repr__(self):
- """
- Return a string representation for a particular ``DataFrame``.
-
- Returns
- -------
- str
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- num_rows = pandas.get_option("display.max_rows") or len(self)
-        # see _repr_html_ for the comment; also allow the display-all-columns behavior here
- num_cols = pandas.get_option("display.max_columns") or len(self.columns)
-
- (
- row_count,
- col_count,
- repr_df,
- ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x")
- result = repr(repr_df)
-
- # if truncated, add shape information
- if is_repr_truncated(row_count, col_count, num_rows, num_cols):
- # The split here is so that we don't repr pandas row lengths.
- return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format(
- row_count, col_count
- )
- else:
- return result
-
- def _repr_html_(self): # pragma: no cover
- """
-        Return an HTML representation for a particular ``DataFrame``.
-
- Returns
- -------
- str
-
- Notes
- -----
- Supports pandas `display.max_rows` and `display.max_columns` options.
- """
- num_rows = pandas.get_option("display.max_rows") or 60
-        # Modin uses 20 as the default here, but this does not align well with the pandas option. Therefore,
-        # allow value=0 here, which means display all columns.
- num_cols = pandas.get_option("display.max_columns")
-
- (
- row_count,
- col_count,
- repr_df,
- ) = self._query_compiler.build_repr_df(num_rows, num_cols)
- result = repr_df._repr_html_()
-
- if is_repr_truncated(row_count, col_count, num_rows, num_cols):
- # We split so that we insert our correct dataframe dimensions.
-            return (
-                result.split("<p>")[0]
-                + f"<p>{row_count} rows × {col_count} columns</p>\n"
-            )
- else:
- return result
-
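A minimal, self-contained sketch (not part of the removed file) of the truncation behavior the two repr methods above rely on, assuming only native pandas defaults; the data is hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"a": range(100)})            # exceeds display.max_rows (default 60)
text = repr(frame)
assert text.rstrip().endswith("[100 rows x 1 columns]")   # shape line is appended once truncated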
- def _get_columns(self) -> pandas.Index:
- """
- Get the columns for this Snowpark pandas ``DataFrame``.
-
- Returns
- -------
- Index
-            All the columns of this DataFrame.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._query_compiler.columns
-
- def _set_columns(self, new_columns: Axes) -> None:
- """
- Set the columns for this Snowpark pandas ``DataFrame``.
-
- Parameters
- ----------
- new_columns :
- The new columns to set.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- self._update_inplace(
- new_query_compiler=self._query_compiler.set_columns(new_columns)
- )
-
- columns = property(_get_columns, _set_columns)
-
- @property
- def ndim(self) -> int:
- return 2
-
- def drop_duplicates(
- self, subset=None, keep="first", inplace=False, ignore_index=False
- ): # noqa: PR01, RT01, D200
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- """
- Return ``DataFrame`` with duplicate rows removed.
- """
- return super().drop_duplicates(
- subset=subset, keep=keep, inplace=inplace, ignore_index=ignore_index
- )
-
- def dropna(
- self,
- *,
- axis: Axis = 0,
- how: str | NoDefault = no_default,
- thresh: int | NoDefault = no_default,
- subset: IndexLabel = None,
- inplace: bool = False,
- ): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super()._dropna(
- axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace
- )
-
- @property
- def dtypes(self): # noqa: RT01, D200
- """
- Return the dtypes in the ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._query_compiler.dtypes
-
- def duplicated(
- self, subset: Hashable | Sequence[Hashable] = None, keep: DropKeep = "first"
- ):
- """
- Return boolean ``Series`` denoting duplicate rows.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- df = self[subset] if subset is not None else self
- new_qc = df._query_compiler.duplicated(keep=keep)
- duplicates = self._reduce_dimension(new_qc)
-        # Remove the Series name, which was assigned automatically by .apply in the query compiler.
-        # This matches pandas behavior, i.e., if the duplicated result is a Series, no name is returned.
- duplicates.name = None
- return duplicates
-
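As a hedged illustration of the name reset above, native pandas likewise returns an unnamed boolean Series from duplicated(); the values below are hypothetical.

import pandas as native_pd

dupes = native_pd.DataFrame({"a": [1, 1, 2]}).duplicated()
assert dupes.name is None                       # the result Series carries no name
assert dupes.tolist() == [False, True, False]   # the second row repeats the first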
- @property
- def empty(self) -> bool:
- """
- Indicate whether ``DataFrame`` is empty.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return len(self.columns) == 0 or len(self) == 0
-
- @property
- def axes(self):
- """
- Return a list representing the axes of the ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return [self.index, self.columns]
-
- @property
- def shape(self) -> tuple[int, int]:
- """
- Return a tuple representing the dimensionality of the ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return len(self), len(self.columns)
-
- def add_prefix(self, prefix):
- """
- Prefix labels with string `prefix`.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- # pandas converts non-string prefix values into str and adds it to the column labels.
- return self.__constructor__(
- query_compiler=self._query_compiler.add_substring(
- str(prefix), substring_type="prefix", axis=1
- )
- )
-
- def add_suffix(self, suffix):
- """
- Suffix labels with string `suffix`.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- # pandas converts non-string suffix values into str and appends it to the column labels.
- return self.__constructor__(
- query_compiler=self._query_compiler.add_substring(
- str(suffix), substring_type="suffix", axis=1
- )
- )
-
- @dataframe_not_implemented()
- def map(
- self, func, na_action: str | None = None, **kwargs
- ) -> DataFrame: # pragma: no cover
- if not callable(func):
- raise ValueError(f"'{type(func)}' object is not callable")
- return self.__constructor__(
- query_compiler=self._query_compiler.map(func, na_action=na_action, **kwargs)
- )
-
- def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if not callable(func):
- raise TypeError(f"{func} is not callable")
- return self.__constructor__(
- query_compiler=self._query_compiler.applymap(
- func, na_action=na_action, **kwargs
- )
- )
-
- def aggregate(
- self, func: AggFuncType = None, axis: Axis = 0, *args: Any, **kwargs: Any
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().aggregate(func, axis, *args, **kwargs)
-
- agg = aggregate
-
- def apply(
- self,
- func: AggFuncType | UserDefinedFunction,
- axis: Axis = 0,
- raw: bool = False,
- result_type: Literal["expand", "reduce", "broadcast"] | None = None,
- args=(),
- **kwargs,
- ):
- """
- Apply a function along an axis of the ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- axis = self._get_axis_number(axis)
- query_compiler = self._query_compiler.apply(
- func,
- axis,
- raw=raw,
- result_type=result_type,
- args=args,
- **kwargs,
- )
- if not isinstance(query_compiler, type(self._query_compiler)):
- # A scalar was returned
- return query_compiler
-
-        # If True, it is an unnamed series.
- # Theoretically, if df.apply returns a Series, it will only be an unnamed series
- # because the function is supposed to be series -> scalar.
- if query_compiler._modin_frame.is_unnamed_series():
- return Series(query_compiler=query_compiler)
- else:
- return self.__constructor__(query_compiler=query_compiler)
-
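A small sketch of the return-shape rule handled above, shown with native pandas apply to keep it self-contained; the data is hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"a": [1, 2], "b": [3, 4]})
assert isinstance(frame.apply(sum, axis=0), native_pd.Series)                      # series -> scalar reduces to a Series
assert isinstance(frame.apply(lambda col: col * 2, axis=0), native_pd.DataFrame)   # element-wise keeps a DataFrame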
- def groupby(
- self,
- by=None,
- axis: Axis | NoDefault = no_default,
- level: IndexLabel | None = None,
- as_index: bool = True,
- sort: bool = True,
- group_keys: bool = True,
- observed: bool | NoDefault = no_default,
- dropna: bool = True,
- ):
- """
- Group ``DataFrame`` using a mapper or by a ``Series`` of columns.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if axis is not no_default:
- axis = self._get_axis_number(axis)
- if axis == 1:
- warnings.warn(
- "DataFrame.groupby with axis=1 is deprecated. Do "
- + "`frame.T.groupby(...)` without axis instead.",
- FutureWarning,
- stacklevel=1,
- )
- else:
- warnings.warn(
- "The 'axis' keyword in DataFrame.groupby is deprecated and "
- + "will be removed in a future version.",
- FutureWarning,
- stacklevel=1,
- )
- else:
- axis = 0
-
- validate_groupby_args(by, level, observed)
-
- axis = self._get_axis_number(axis)
-
- if axis != 0 and as_index is False:
- raise ValueError("as_index=False only valid for axis=0")
-
- idx_name = None
-
- if (
- not isinstance(by, Series)
- and is_list_like(by)
- and len(by) == 1
- # if by is a list-like of (None,), we have to keep it as a list because
- # None may be referencing a column or index level whose label is
-            # `None`, and by=None would mean that there is no `by` param.
- and by[0] is not None
- ):
- by = by[0]
-
- if hashable(by) and (
- not callable(by) and not isinstance(by, (pandas.Grouper, FrozenList))
- ):
- idx_name = by
- elif isinstance(by, Series):
- idx_name = by.name
- if by._parent is self:
- # if the SnowSeries comes from the current dataframe,
- # convert it to labels directly for easy processing
- by = by.name
- elif is_list_like(by):
- if axis == 0 and all(
- (
- (hashable(o) and (o in self))
- or isinstance(o, Series)
- or (is_list_like(o) and len(o) == len(self.shape[axis]))
- )
- for o in by
- ):
-                # Split 'by's into those that belong to self (internal_by)
-                # and those that don't (external_by). For a SnowSeries that belongs
-                # to the current DataFrame, we convert it to labels for easy processing.
- internal_by, external_by = [], []
-
- for current_by in by:
- if hashable(current_by):
- internal_by.append(current_by)
- elif isinstance(current_by, Series):
- if current_by._parent is self:
- internal_by.append(current_by.name)
- else:
- external_by.append(current_by) # pragma: no cover
- else:
- external_by.append(current_by)
-
- by = internal_by + external_by
-
- return DataFrameGroupBy(
- self,
- by,
- axis,
- level,
- as_index,
- sort,
- group_keys,
- idx_name,
- observed=observed,
- dropna=dropna,
- )
-
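A hedged usage sketch of the label-versus-Series handling above, using native pandas groupby; column names and values are hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"k": ["x", "x", "y"], "v": [1, 2, 3]})
by_label = frame.groupby("k")["v"].sum()
by_series = frame.groupby(frame["k"])["v"].sum()   # a Series taken from the same frame
assert by_label.equals(by_series)                  # both group on column "k"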
- def keys(self): # noqa: RT01, D200
- """
- Get columns of the ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self.columns
-
- def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200
- """
- Transpose index and columns.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if copy:
- WarningMessage.ignored_argument(
- operation="transpose",
- argument="copy",
-                message="Transpose ignores the copy argument in Snowpark pandas API",
- )
-
- if args:
- WarningMessage.ignored_argument(
- operation="transpose",
- argument="args",
- message="Transpose ignores args in Snowpark pandas API",
- )
-
- return self.__constructor__(query_compiler=self._query_compiler.transpose())
-
- T = property(transpose)
-
- def add(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get addition of ``DataFrame`` and `other`, element-wise (binary operator `add`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "add",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- def assign(self, **kwargs): # noqa: PR01, RT01, D200
- """
- Assign new columns to a ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-
- df = self.copy()
- for k, v in kwargs.items():
- if callable(v):
- df[k] = v(df)
- else:
- df[k] = v
- return df
-
- @dataframe_not_implemented()
- def boxplot(
- self,
- column=None,
- by=None,
- ax=None,
- fontsize=None,
- rot=0,
- grid=True,
- figsize=None,
- layout=None,
- return_type=None,
- backend=None,
- **kwargs,
- ): # noqa: PR01, RT01, D200
- """
- Make a box plot from ``DataFrame`` columns.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return to_pandas(self).boxplot(
- column=column,
- by=by,
- ax=ax,
- fontsize=fontsize,
- rot=rot,
- grid=grid,
- figsize=figsize,
- layout=layout,
- return_type=return_type,
- backend=backend,
- **kwargs,
- )
-
- @dataframe_not_implemented()
- def combine(
- self, other, func, fill_value=None, overwrite=True
- ): # noqa: PR01, RT01, D200
- """
- Perform column-wise combine with another ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().combine(other, func, fill_value=fill_value, overwrite=overwrite)
-
- def compare(
- self,
- other,
- align_axis=1,
- keep_shape: bool = False,
- keep_equal: bool = False,
- result_names=("self", "other"),
- ) -> DataFrame: # noqa: PR01, RT01, D200
- """
- Compare to another ``DataFrame`` and show the differences.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if not isinstance(other, DataFrame):
- raise TypeError(f"Cannot compare DataFrame to {type(other)}")
- other = self._validate_other(other, 0, compare_index=True)
- return self.__constructor__(
- query_compiler=self._query_compiler.compare(
- other,
- align_axis=align_axis,
- keep_shape=keep_shape,
- keep_equal=keep_equal,
- result_names=result_names,
- )
- )
-
- def corr(
- self,
- method: str | Callable = "pearson",
- min_periods: int | None = None,
- numeric_only: bool = False,
- ): # noqa: PR01, RT01, D200
- """
- Compute pairwise correlation of columns, excluding NA/null values.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- corr_df = self
- if numeric_only:
- corr_df = self.drop(
- columns=[
- i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i])
- ]
- )
- return self.__constructor__(
- query_compiler=corr_df._query_compiler.corr(
- method=method,
- min_periods=min_periods,
- )
- )
-
- @dataframe_not_implemented()
- def corrwith(
- self, other, axis=0, drop=False, method="pearson", numeric_only=False
- ): # noqa: PR01, RT01, D200
- """
- Compute pairwise correlation.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if isinstance(other, DataFrame):
- other = other._query_compiler.to_pandas()
- return self._default_to_pandas(
- pandas.DataFrame.corrwith,
- other,
- axis=axis,
- drop=drop,
- method=method,
- numeric_only=numeric_only,
- )
-
- @dataframe_not_implemented()
- def cov(
- self,
- min_periods: int | None = None,
- ddof: int | None = 1,
- numeric_only: bool = False,
- ):
- """
- Compute pairwise covariance of columns, excluding NA/null values.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self.__constructor__(
- query_compiler=self._query_compiler.cov(
- min_periods=min_periods,
- ddof=ddof,
- numeric_only=numeric_only,
- )
- )
-
- @dataframe_not_implemented()
- def dot(self, other): # noqa: PR01, RT01, D200
- """
- Compute the matrix multiplication between the ``DataFrame`` and `other`.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-
- if isinstance(other, BasePandasDataset):
- common = self.columns.union(other.index)
- if len(common) > len(self.columns) or len(common) > len(
- other
- ): # pragma: no cover
- raise ValueError("Matrices are not aligned")
-
- if isinstance(other, DataFrame):
- return self.__constructor__(
- query_compiler=self._query_compiler.dot(
- other.reindex(index=common), squeeze_self=False
- )
- )
- else:
- return self._reduce_dimension(
- query_compiler=self._query_compiler.dot(
- other.reindex(index=common), squeeze_self=False
- )
- )
-
- other = np.asarray(other)
- if self.shape[1] != other.shape[0]:
- raise ValueError(
- f"Dot product shape mismatch, {self.shape} vs {other.shape}"
- )
-
- if len(other.shape) > 1:
- return self.__constructor__(
- query_compiler=self._query_compiler.dot(other, squeeze_self=False)
- )
-
- return self._reduce_dimension(
- query_compiler=self._query_compiler.dot(other, squeeze_self=False)
- )
-
- def eq(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200
- """
- Perform equality comparison of ``DataFrame`` and `other` (binary operator `eq`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("eq", other, axis=axis, level=level)
-
- def equals(self, other) -> bool: # noqa: PR01, RT01, D200
- """
- Test whether two objects contain the same elements.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if isinstance(other, pandas.DataFrame):
- # Copy into a Modin DataFrame to simplify logic below
- other = self.__constructor__(other)
-
- if (
- type(self) is not type(other)
- or not self.index.equals(other.index)
- or not self.columns.equals(other.columns)
- ):
- return False
-
- result = self.__constructor__(
- query_compiler=self._query_compiler.equals(other._query_compiler)
- )
- return result.all(axis=None)
-
- def _update_var_dicts_in_kwargs(self, expr, kwargs):
- """
- Copy variables with "@" prefix in `local_dict` and `global_dict` keys of kwargs.
-
- Parameters
- ----------
- expr : str
- The expression string to search variables with "@" prefix.
- kwargs : dict
- See the documentation for eval() for complete details on the keyword arguments accepted by query().
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if "@" not in expr:
- return
- frame = sys._getframe()
- try:
- f_locals = frame.f_back.f_back.f_back.f_back.f_locals
- f_globals = frame.f_back.f_back.f_back.f_back.f_globals
- finally:
- del frame
- local_names = set(re.findall(r"@([\w]+)", expr))
- local_dict = {}
- global_dict = {}
-
- for name in local_names:
- for dct_out, dct_in in ((local_dict, f_locals), (global_dict, f_globals)):
- try:
- dct_out[name] = dct_in[name]
- except KeyError:
- pass
-
- if local_dict:
- local_dict.update(kwargs.get("local_dict") or {})
- kwargs["local_dict"] = local_dict
- if global_dict:
- global_dict.update(kwargs.get("global_dict") or {})
- kwargs["global_dict"] = global_dict
-
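A minimal sketch of the "@variable" resolution the helper above performs, shown with native pandas query(); the variable and frame are hypothetical.

import pandas as native_pd

threshold = 1
frame = native_pd.DataFrame({"a": [1, 2, 3]})
filtered = frame.query("a > @threshold")           # "@threshold" is looked up in the caller's scope
assert filtered["a"].tolist() == [2, 3]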
- @dataframe_not_implemented()
- def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200
- """
- Evaluate a string describing operations on ``DataFrame`` columns.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- self._validate_eval_query(expr, **kwargs)
- inplace = validate_bool_kwarg(inplace, "inplace")
- self._update_var_dicts_in_kwargs(expr, kwargs)
- new_query_compiler = self._query_compiler.eval(expr, **kwargs)
- return_type = type(
- pandas.DataFrame(columns=self.columns)
- .astype(self.dtypes)
- .eval(expr, **kwargs)
- ).__name__
- if return_type == type(self).__name__:
- return self._create_or_update_from_compiler(new_query_compiler, inplace)
- else:
- if inplace:
- raise ValueError("Cannot operate inplace if there is no assignment")
- return getattr(sys.modules[self.__module__], return_type)(
- query_compiler=new_query_compiler
- )
-
- def fillna(
- self,
- value: Hashable | Mapping | Series | DataFrame = None,
- *,
- method: FillnaOptions | None = None,
- axis: Axis | None = None,
- inplace: bool = False,
- limit: int | None = None,
- downcast: dict | None = None,
- ) -> DataFrame | None:
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().fillna(
- self_is_series=False,
- value=value,
- method=method,
- axis=axis,
- inplace=inplace,
- limit=limit,
- downcast=downcast,
- )
-
- def floordiv(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `floordiv`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "floordiv",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- @classmethod
- @dataframe_not_implemented()
- def from_dict(
- cls, data, orient="columns", dtype=None, columns=None
- ): # pragma: no cover # noqa: PR01, RT01, D200
- """
- Construct ``DataFrame`` from dict of array-like or dicts.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return from_pandas(
- pandas.DataFrame.from_dict(
- data, orient=orient, dtype=dtype, columns=columns
- )
- )
-
- @classmethod
- @dataframe_not_implemented()
- def from_records(
- cls,
- data,
- index=None,
- exclude=None,
- columns=None,
- coerce_float=False,
- nrows=None,
- ): # pragma: no cover # noqa: PR01, RT01, D200
- """
- Convert structured or record ndarray to ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return from_pandas(
- pandas.DataFrame.from_records(
- data,
- index=index,
- exclude=exclude,
- columns=columns,
- coerce_float=coerce_float,
- nrows=nrows,
- )
- )
-
- def ge(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200
- """
- Get greater than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("ge", other, axis=axis, level=level)
-
- def gt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200
- """
-        Get greater than comparison of ``DataFrame`` and `other`, element-wise (binary operator `gt`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("gt", other, axis=axis, level=level)
-
- @dataframe_not_implemented()
- def hist(
- self,
- column=None,
- by=None,
- grid=True,
- xlabelsize=None,
- xrot=None,
- ylabelsize=None,
- yrot=None,
- ax=None,
- sharex=False,
- sharey=False,
- figsize=None,
- layout=None,
- bins=10,
- **kwds,
- ): # pragma: no cover # noqa: PR01, RT01, D200
- """
- Make a histogram of the ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.hist,
- column=column,
- by=by,
- grid=grid,
- xlabelsize=xlabelsize,
- xrot=xrot,
- ylabelsize=ylabelsize,
- yrot=yrot,
- ax=ax,
- sharex=sharex,
- sharey=sharey,
- figsize=figsize,
- layout=layout,
- bins=bins,
- **kwds,
- )
-
- def info(
- self,
- verbose: bool | None = None,
- buf: IO[str] | None = None,
- max_cols: int | None = None,
- memory_usage: bool | str | None = None,
- show_counts: bool | None = None,
- null_counts: bool | None = None,
- ): # noqa: PR01, D200
- """
- Print a concise summary of the ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- def put_str(src, output_len=None, spaces=2):
- src = str(src)
- return src.ljust(output_len if output_len else len(src)) + " " * spaces
-
- def format_size(num):
- for x in ["bytes", "KB", "MB", "GB", "TB"]:
- if num < 1024.0:
- return f"{num:3.1f} {x}"
- num /= 1024.0
- return f"{num:3.1f} PB"
-
- output = []
-
- type_line = str(type(self))
- index_line = "SnowflakeIndex"
- columns = self.columns
- columns_len = len(columns)
- dtypes = self.dtypes
- dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}"
-
- if max_cols is None:
- max_cols = 100
-
- exceeds_info_cols = columns_len > max_cols
-
- if buf is None:
- buf = sys.stdout
-
- if null_counts is None:
- null_counts = not exceeds_info_cols
-
- if verbose is None:
- verbose = not exceeds_info_cols
-
- if null_counts and verbose:
-            # We are going to take items from `non_null_count` in a loop, which
-            # is slow with a Modin Series; that is why we call `_to_pandas()` here,
-            # which will be faster.
- non_null_count = self.count()._to_pandas()
-
- if memory_usage is None:
- memory_usage = True
-
- def get_header(spaces=2):
- output = []
- head_label = " # "
- column_label = "Column"
- null_label = "Non-Null Count"
- dtype_label = "Dtype"
- non_null_label = " non-null"
- delimiter = "-"
-
- lengths = {}
- lengths["head"] = max(len(head_label), len(pprint_thing(len(columns))))
- lengths["column"] = max(
- len(column_label), max(len(pprint_thing(col)) for col in columns)
- )
- lengths["dtype"] = len(dtype_label)
- dtype_spaces = (
- max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes))
- - lengths["dtype"]
- )
-
- header = put_str(head_label, lengths["head"]) + put_str(
- column_label, lengths["column"]
- )
- if null_counts:
- lengths["null"] = max(
- len(null_label),
- max(len(pprint_thing(x)) for x in non_null_count)
- + len(non_null_label),
- )
- header += put_str(null_label, lengths["null"])
- header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces)
-
- output.append(header)
-
- delimiters = put_str(delimiter * lengths["head"]) + put_str(
- delimiter * lengths["column"]
- )
- if null_counts:
- delimiters += put_str(delimiter * lengths["null"])
- delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces)
- output.append(delimiters)
-
- return output, lengths
-
- output.extend([type_line, index_line])
-
- def verbose_repr(output):
- columns_line = f"Data columns (total {len(columns)} columns):"
- header, lengths = get_header()
- output.extend([columns_line, *header])
- for i, col in enumerate(columns):
- i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]])
-
- to_append = put_str(f" {i}", lengths["head"]) + put_str(
- col_s, lengths["column"]
- )
- if null_counts:
- non_null = pprint_thing(non_null_count[col])
- to_append += put_str(f"{non_null} non-null", lengths["null"])
- to_append += put_str(dtype, lengths["dtype"], spaces=0)
- output.append(to_append)
-
- def non_verbose_repr(output):
- output.append(columns._summary(name="Columns"))
-
- if verbose:
- verbose_repr(output)
- else:
- non_verbose_repr(output)
-
- output.append(dtypes_line)
-
- if memory_usage:
- deep = memory_usage == "deep"
- mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum()
- mem_line = f"memory usage: {format_size(mem_usage_bytes)}"
-
- output.append(mem_line)
-
- output.append("")
- buf.write("\n".join(output))
-
- def insert(
- self,
- loc: int,
- column: Hashable,
- value: Scalar | AnyArrayLike,
- allow_duplicates: bool | NoDefault = no_default,
- ) -> None:
- """
- Insert column into ``DataFrame`` at specified location.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- raise_if_native_pandas_objects(value)
- if allow_duplicates is no_default:
- allow_duplicates = False
- if not allow_duplicates and column in self.columns:
- raise ValueError(f"cannot insert {column}, already exists")
-
- if not isinstance(loc, int):
- raise TypeError("loc must be int")
-
-        # If column labels are multilevel, we implement the following behavior (this is
-        # the same as native pandas):
-        # Case 1: if 'column' is a tuple, its length must be the same as the number of
-        # levels; otherwise raise an error.
-        # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in
-        # empty strings to match the length of the column levels in this frame.
- if self.columns.nlevels > 1:
- if isinstance(column, tuple) and len(column) != self.columns.nlevels:
- # same error as native pandas.
- raise ValueError("Item must have length equal to number of levels.")
- if not isinstance(column, tuple):
- # Fill empty strings to match length of levels
- suffix = [""] * (self.columns.nlevels - 1)
- column = tuple([column] + suffix)
-
-        # Dictionary keys are treated as an index column, and this should be joined with
-        # the index of the target dataframe. This behavior is similar to 'value' being a
-        # DataFrame or Series, so we simply create a Series from the dict data here.
- if isinstance(value, dict):
- value = Series(value, name=column)
-
- if isinstance(value, DataFrame) or (
- isinstance(value, np.ndarray) and len(value.shape) > 1
- ):
- # Supported numpy array shapes are
- # 1. (N, ) -> Ex. [1, 2, 3]
-            # 2. (N, 1) -> Ex. [[1], [2], [3]]
- if value.shape[1] != 1:
- if isinstance(value, DataFrame):
- # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin
- raise ValueError(
- f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead."
- )
- else:
- raise ValueError(
- f"Expected a 1D array, got an array with shape {value.shape}"
- )
- # Change numpy array shape from (N, 1) to (N, )
- if isinstance(value, np.ndarray):
- value = value.squeeze(axis=1)
-
- if (
- is_list_like(value)
- and not isinstance(value, (Series, DataFrame))
- and len(value) != self.shape[0]
- and not 0 == self.shape[0] # dataframe holds no rows
- ):
- raise ValueError(
- "Length of values ({}) does not match length of index ({})".format(
- len(value), len(self)
- )
- )
- if not -len(self.columns) <= loc <= len(self.columns):
- raise IndexError(
- f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}"
- )
- elif loc < 0:
- raise ValueError("unbounded slice")
-
- join_on_index = False
- if isinstance(value, (Series, DataFrame)):
- value = value._query_compiler
- join_on_index = True
- elif is_list_like(value):
- value = Series(value, name=column)._query_compiler
-
- new_query_compiler = self._query_compiler.insert(
- loc, column, value, join_on_index
- )
- # In pandas, 'insert' operation is always inplace.
- self._update_inplace(new_query_compiler=new_query_compiler)
-
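A short sketch of the (N, 1) shape handling described above, using numpy and native pandas; the values are hypothetical.

import numpy as np
import pandas as native_pd

frame = native_pd.DataFrame({"a": [1, 2, 3]})
column = np.array([[10], [20], [30]])              # shape (N, 1)
frame.insert(1, "b", column.squeeze(axis=1))       # squeeze to shape (N,) before inserting
assert list(frame.columns) == ["a", "b"]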
- @dataframe_not_implemented()
- def interpolate(
- self,
- method="linear",
- axis=0,
- limit=None,
- inplace=False,
- limit_direction: str | None = None,
- limit_area=None,
- downcast=None,
- **kwargs,
- ): # noqa: PR01, RT01, D200
- """
- Fill NaN values using an interpolation method.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.interpolate,
- method=method,
- axis=axis,
- limit=limit,
- inplace=inplace,
- limit_direction=limit_direction,
- limit_area=limit_area,
- downcast=downcast,
- **kwargs,
- )
-
- def iterrows(self) -> Iterator[tuple[Hashable, Series]]:
- """
- Iterate over ``DataFrame`` rows as (index, ``Series``) pairs.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- def iterrow_builder(s):
-            """Return a tuple of the given `s` parameter's name and the parameter itself."""
- return s.name, s
-
- # Raise warning message since iterrows is very inefficient.
- WarningMessage.single_warning(
- DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows")
- )
-
- partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder)
- yield from partition_iterator
-
- def items(self): # noqa: D200
- """
- Iterate over (column name, ``Series``) pairs.
- """
-
- def items_builder(s):
-            """Return a tuple of the given `s` parameter's name and the parameter itself."""
- return s.name, s
-
- partition_iterator = PartitionIterator(self, 1, items_builder)
- yield from partition_iterator
-
- def itertuples(
- self, index: bool = True, name: str | None = "Pandas"
- ) -> Iterable[tuple[Any, ...]]:
- """
- Iterate over ``DataFrame`` rows as ``namedtuple``-s.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-
- def itertuples_builder(s):
- """Return the next namedtuple."""
- # s is the Series of values in the current row.
- fields = [] # column names
- data = [] # values under each column
-
- if index:
- data.append(s.name)
- fields.append("Index")
-
- # Fill column names and values.
- fields.extend(list(self.columns))
- data.extend(s)
-
- if name is not None:
- # Creating the namedtuple.
- itertuple = collections.namedtuple(name, fields, rename=True)
- return itertuple._make(data)
-
- # When the name is None, return a regular tuple.
- return tuple(data)
-
- # Raise warning message since itertuples is very inefficient.
- WarningMessage.single_warning(
- DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples")
- )
- return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True)
-
- def join(
- self,
- other: DataFrame | Series | Iterable[DataFrame | Series],
- on: IndexLabel | None = None,
- how: str = "left",
- lsuffix: str = "",
- rsuffix: str = "",
- sort: bool = False,
- validate: str | None = None,
- ) -> DataFrame:
- """
- Join columns of another ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- for o in other if isinstance(other, list) else [other]:
- raise_if_native_pandas_objects(o)
-
-        # Similar to native pandas, we implement 'join' using the 'pd.merge' method.
-        # The following code is copied from native pandas (with a few changes explained below):
- # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002
- if isinstance(other, Series):
- # Same error as native pandas.
- if other.name is None:
- raise ValueError("Other Series must have a name")
- other = DataFrame(other)
- elif is_list_like(other):
- if any([isinstance(o, Series) and o.name is None for o in other]):
- raise ValueError("Other Series must have a name")
-
- if isinstance(other, DataFrame):
- if how == "cross":
- return pd.merge(
- self,
- other,
- how=how,
- on=on,
- suffixes=(lsuffix, rsuffix),
- sort=sort,
- validate=validate,
- )
- return pd.merge(
- self,
- other,
- left_on=on,
- how=how,
- left_index=on is None,
- right_index=True,
- suffixes=(lsuffix, rsuffix),
- sort=sort,
- validate=validate,
- )
- else: # List of DataFrame/Series
- # Same error as native pandas.
- if on is not None:
- raise ValueError(
- "Joining multiple DataFrames only supported for joining on index"
- )
-
- # Same error as native pandas.
- if rsuffix or lsuffix:
- raise ValueError(
- "Suffixes not supported when joining multiple DataFrames"
- )
-
-            # NOTE: These are not differences between the Snowpark pandas API and pandas behavior;
-            # these are differences in native pandas join behavior depending on whether the joined
-            # frames have a unique index or not.
-
- # In native pandas logic to join multiple DataFrames/Series is data
- # dependent. Under the hood it will either use 'concat' or 'merge' API
- # Case 1. If all objects being joined have unique index use 'concat' (axis=1)
- # Case 2. Otherwise use 'merge' API by looping through objects left to right.
- # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046
-
- # Even though concat (axis=1) and merge are very similar APIs they have
- # some differences which leads to inconsistent behavior in native pandas.
- # 1. Treatment of un-named Series
- # Case #1: Un-named series is allowed in concat API. Objects are joined
- # successfully by assigning a number as columns name (see 'concat' API
- # documentation for details on treatment of un-named series).
- # Case #2: It raises 'ValueError: Other Series must have a name'
-
- # 2. how='right'
- # Case #1: 'concat' API doesn't support right join. It raises
- # 'ValueError: Only can inner (intersect) or outer (union) join the other axis'
- # Case #2: Merges successfully.
-
- # 3. Joining frames with duplicate labels but no conflict with other frames
- # Example: self = DataFrame(... columns=["A", "B"])
- # other = [DataFrame(... columns=["C", "C"])]
- # Case #1: 'ValueError: Indexes have overlapping values'
- # Case #2: Merged successfully.
-
- # In addition to this, native pandas implementation also leads to another
- # type of inconsistency where left.join(other, ...) and
- # left.join([other], ...) might behave differently for cases mentioned
- # above.
- # Example:
- # import pandas as pd
- # df = pd.DataFrame({"a": [4, 5]})
- # other = pd.Series([1, 2])
- # df.join([other]) # this is successful
- # df.join(other) # this raises 'ValueError: Other Series must have a name'
-
- # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API
-            # to join multiple DataFrames/Series, so we always follow the behavior
- # documented as Case #2 above.
-
- joined = self
- for frame in other:
- if isinstance(frame, DataFrame):
- overlapping_cols = set(joined.columns).intersection(
- set(frame.columns)
- )
- if len(overlapping_cols) > 0:
- # Native pandas raises: 'Indexes have overlapping values'
- # We differ slightly from native pandas message to make it more
- # useful to users.
- raise ValueError(
- f"Join dataframes have overlapping column labels: {overlapping_cols}"
- )
- joined = pd.merge(
- joined,
- frame,
- how=how,
- left_index=True,
- right_index=True,
- validate=validate,
- sort=sort,
- suffixes=(None, None),
- )
- return joined
-
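A native pandas sketch of the naming rule referenced in the comments above (a joined Series must have a name); the frames are hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"a": [4, 5]})
named = native_pd.Series([1, 2], name="b")
joined = frame.join(named)                         # succeeds because the Series has a name
assert list(joined.columns) == ["a", "b"]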
- def isna(self):
- return super().isna()
-
- def isnull(self):
- return super().isnull()
-
- @dataframe_not_implemented()
- def isetitem(self, loc, value):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.isetitem,
- loc=loc,
- value=value,
- )
-
- def le(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200
- """
- Get less than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("le", other, axis=axis, level=level)
-
- def lt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200
- """
-        Get less than comparison of ``DataFrame`` and `other`, element-wise (binary operator `lt`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("lt", other, axis=axis, level=level)
-
- def melt(
- self,
- id_vars=None,
- value_vars=None,
- var_name=None,
- value_name="value",
- col_level=None,
- ignore_index=True,
- ): # noqa: PR01, RT01, D200
- """
- Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if id_vars is None:
- id_vars = []
- if not is_list_like(id_vars):
- id_vars = [id_vars]
- if value_vars is None:
- # Behavior of Index.difference changed in 2.2.x
- # https://github.com/pandas-dev/pandas/pull/55113
- # This change needs upstream to Modin:
- # https://github.com/modin-project/modin/issues/7206
- value_vars = self.columns.drop(id_vars)
- if var_name is None:
- columns_name = self._query_compiler.get_index_name(axis=1)
- var_name = columns_name if columns_name is not None else "variable"
- return self.__constructor__(
- query_compiler=self._query_compiler.melt(
- id_vars=id_vars,
- value_vars=value_vars,
- var_name=var_name,
- value_name=value_name,
- col_level=col_level,
- ignore_index=ignore_index,
- )
- )
-
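A minimal sketch of the value_vars default described above, using native pandas melt; column names are hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"id": [1, 2], "a": [3, 4], "b": [5, 6]})
long_frame = native_pd.melt(frame, id_vars=["id"])   # value_vars defaults to the remaining columns
assert sorted(long_frame["variable"].unique()) == ["a", "b"]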
- @dataframe_not_implemented()
- def memory_usage(self, index=True, deep=False): # noqa: PR01, RT01, D200
- """
- Return the memory usage of each column in bytes.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-
- if index:
- result = self._reduce_dimension(
- self._query_compiler.memory_usage(index=False, deep=deep)
- )
- index_value = self.index.memory_usage(deep=deep)
- return pd.concat(
- [Series(index_value, index=["Index"]), result]
- ) # pragma: no cover
- return super().memory_usage(index=index, deep=deep)
-
- def merge(
- self,
- right: DataFrame | Series,
- how: str = "inner",
- on: IndexLabel | None = None,
- left_on: Hashable
- | AnyArrayLike
- | Sequence[Hashable | AnyArrayLike]
- | None = None,
- right_on: Hashable
- | AnyArrayLike
- | Sequence[Hashable | AnyArrayLike]
- | None = None,
- left_index: bool = False,
- right_index: bool = False,
- sort: bool = False,
- suffixes: Suffixes = ("_x", "_y"),
- copy: bool = True,
- indicator: bool = False,
- validate: str | None = None,
- ) -> DataFrame:
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- # Raise error if native pandas objects are passed.
- raise_if_native_pandas_objects(right)
-
- if isinstance(right, Series) and right.name is None:
- raise ValueError("Cannot merge a Series without a name")
- if not isinstance(right, (Series, DataFrame)):
- raise TypeError(
- f"Can only merge Series or DataFrame objects, a {type(right)} was passed"
- )
-
- if isinstance(right, Series):
- right_column_nlevels = (
- len(right.name) if isinstance(right.name, tuple) else 1
- )
- else:
- right_column_nlevels = right.columns.nlevels
- if self.columns.nlevels != right_column_nlevels:
-            # This is deprecated in native pandas. We raise an explicit error for this.
- raise ValueError(
- "Can not merge objects with different column levels."
- + f" ({self.columns.nlevels} levels on the left,"
- + f" {right_column_nlevels} on the right)"
- )
-
- # Merge empty native pandas dataframes for error checking. Otherwise, it will
- # require a lot of logic to be written. This takes care of raising errors for
- # following scenarios:
- # 1. Only 'left_index' is set to True.
- # 2. Only 'right_index is set to True.
- # 3. Only 'left_on' is provided.
- # 4. Only 'right_on' is provided.
- # 5. 'on' and 'left_on' both are provided
- # 6. 'on' and 'right_on' both are provided
- # 7. 'on' and 'left_index' both are provided
- # 8. 'on' and 'right_index' both are provided
- # 9. 'left_on' and 'left_index' both are provided
- # 10. 'right_on' and 'right_index' both are provided
- # 11. Length mismatch between 'left_on' and 'right_on'
- # 12. 'left_index' is not a bool
- # 13. 'right_index' is not a bool
- # 14. 'on' is not None and how='cross'
- # 15. 'left_on' is not None and how='cross'
- # 16. 'right_on' is not None and how='cross'
- # 17. 'left_index' is True and how='cross'
- # 18. 'right_index' is True and how='cross'
- # 19. Unknown label in 'on', 'left_on' or 'right_on'
- # 20. Provided 'suffixes' is not sufficient to resolve conflicts.
- # 21. Merging on column with duplicate labels.
- # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'}
- # 23. conflict with existing labels for array-like join key
- # 24. 'indicator' argument is not bool or str
- # 25. indicator column label conflicts with existing data labels
- create_empty_native_pandas_frame(self).merge(
- create_empty_native_pandas_frame(right),
- on=on,
- how=how,
- left_on=replace_external_data_keys_with_empty_pandas_series(left_on),
- right_on=replace_external_data_keys_with_empty_pandas_series(right_on),
- left_index=left_index,
- right_index=right_index,
- suffixes=suffixes,
- indicator=indicator,
- )
-
- return self.__constructor__(
- query_compiler=self._query_compiler.merge(
- right._query_compiler,
- how=how,
- on=on,
- left_on=replace_external_data_keys_with_query_compiler(self, left_on),
- right_on=replace_external_data_keys_with_query_compiler(
- right, right_on
- ),
- left_index=left_index,
- right_index=right_index,
- sort=sort,
- suffixes=suffixes,
- copy=copy,
- indicator=indicator,
- validate=validate,
- )
- )
-
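A hedged sketch of the "merge empty frames for validation" trick described above, using native pandas only; the labels are hypothetical.

import pandas as native_pd

left = native_pd.DataFrame(columns=["a", "b"])
right = native_pd.DataFrame(columns=["b", "c"])
try:
    left.merge(right, on="missing")                # unknown join label
except KeyError:
    pass                                           # native pandas raises before any data is touched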
- def mod(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `mod`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "mod",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- def mul(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "mul",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- multiply = mul
-
- def rmul(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
-        Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `rmul`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "rmul",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200
- """
- Get not equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ne`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("ne", other, axis=axis, level=level)
-
- def nlargest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200
- """
- Return the first `n` rows ordered by `columns` in descending order.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self.__constructor__(
- query_compiler=self._query_compiler.nlargest(n, columns, keep)
- )
-
- def nsmallest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200
- """
- Return the first `n` rows ordered by `columns` in ascending order.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self.__constructor__(
- query_compiler=self._query_compiler.nsmallest(
- n=n, columns=columns, keep=keep
- )
- )
-
- def unstack(
- self,
- level: int | str | list = -1,
- fill_value: int | str | dict = None,
- sort: bool = True,
- ):
- """
- Pivot a level of the (necessarily hierarchical) index labels.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- # This ensures that non-pandas MultiIndex objects are caught.
- nlevels = self._query_compiler.nlevels()
- is_multiindex = nlevels > 1
-
- if not is_multiindex or (
- is_multiindex and is_list_like(level) and len(level) == nlevels
- ):
- return self._reduce_dimension(
- query_compiler=self._query_compiler.unstack(
- level, fill_value, sort, is_series_input=False
- )
- )
- else:
- return self.__constructor__(
- query_compiler=self._query_compiler.unstack(
- level, fill_value, sort, is_series_input=False
- )
- )
-
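A small native pandas illustration of the dimension rule above: unstacking every index level collapses the result to a Series; the data is hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"v": [1, 2]}, index=["r1", "r2"])
result = frame.unstack()                           # the only index level is unstacked
assert isinstance(result, native_pd.Series)        # so the result drops to one dimension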
- def pivot(
- self,
- *,
- columns: Any,
- index: Any | NoDefault = no_default,
- values: Any | NoDefault = no_default,
- ):
- """
- Return reshaped DataFrame organized by given index / column values.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if index is no_default:
- index = None # pragma: no cover
- if values is no_default:
- values = None
-
- # if values is not specified, it should be the remaining columns not in
- # index or columns
- if values is None:
- values = list(self.columns)
- if index is not None:
- values = [v for v in values if v not in index]
- if columns is not None:
- values = [v for v in values if v not in columns]
-
- return self.__constructor__(
- query_compiler=self._query_compiler.pivot(
- index=index, columns=columns, values=values
- )
- )
-
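A sketch of the "values defaults to the remaining columns" rule above, using native pandas pivot; the column names are hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"row": [1, 2], "col": ["a", "b"], "val": [10, 20]})
wide = frame.pivot(index="row", columns="col")     # values not given -> remaining column "val"
assert list(wide.columns) == [("val", "a"), ("val", "b")]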
- def pivot_table(
- self,
- values=None,
- index=None,
- columns=None,
- aggfunc="mean",
- fill_value=None,
- margins=False,
- dropna=True,
- margins_name="All",
- observed=False,
- sort=True,
- ):
- """
- Create a spreadsheet-style pivot table as a ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- result = self.__constructor__(
- query_compiler=self._query_compiler.pivot_table(
- index=index,
- values=values,
- columns=columns,
- aggfunc=aggfunc,
- fill_value=fill_value,
- margins=margins,
- dropna=dropna,
- margins_name=margins_name,
- observed=observed,
- sort=sort,
- )
- )
- return result
-
- @dataframe_not_implemented()
- @property
- def plot(
- self,
- x=None,
- y=None,
- kind="line",
- ax=None,
- subplots=False,
- sharex=None,
- sharey=False,
- layout=None,
- figsize=None,
- use_index=True,
- title=None,
- grid=None,
- legend=True,
- style=None,
- logx=False,
- logy=False,
- loglog=False,
- xticks=None,
- yticks=None,
- xlim=None,
- ylim=None,
- rot=None,
- fontsize=None,
- colormap=None,
- table=False,
- yerr=None,
- xerr=None,
- secondary_y=False,
- sort_columns=False,
- **kwargs,
- ): # noqa: PR01, RT01, D200
- """
- Make plots of ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._to_pandas().plot
-
- def pow(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "pow",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- @dataframe_not_implemented()
- def prod(
- self,
- axis=None,
- skipna=True,
- numeric_only=False,
- min_count=0,
- **kwargs,
- ): # noqa: PR01, RT01, D200
- """
- Return the product of the values over the requested axis.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- validate_bool_kwarg(skipna, "skipna", none_allowed=False)
- axis = self._get_axis_number(axis)
- axis_to_apply = self.columns if axis else self.index
- if (
- skipna is not False
- and numeric_only is None
- and min_count > len(axis_to_apply)
- ):
- new_index = self.columns if not axis else self.index
- return Series(
- [np.nan] * len(new_index), index=new_index, dtype=np.dtype("object")
- )
-
- data = self._validate_dtypes_sum_prod_mean(axis, numeric_only, ignore_axis=True)
- if min_count > 1:
- return data._reduce_dimension(
- data._query_compiler.prod_min_count(
- axis=axis,
- skipna=skipna,
- numeric_only=numeric_only,
- min_count=min_count,
- **kwargs,
- )
- )
- return data._reduce_dimension(
- data._query_compiler.prod(
- axis=axis,
- skipna=skipna,
- numeric_only=numeric_only,
- min_count=min_count,
- **kwargs,
- )
- )
-
- product = prod
-
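A hedged native pandas example of the min_count handling above; the column and values are hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"a": [1.0, None]})
assert frame.prod(min_count=1)["a"] == 1.0             # one valid value is enough
assert native_pd.isna(frame.prod(min_count=3)["a"])    # fewer than 3 valid values -> NaN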
- def quantile(
- self,
- q: Scalar | ListLike = 0.5,
- axis: Axis = 0,
- numeric_only: bool = False,
- interpolation: Literal[
- "linear", "lower", "higher", "midpoint", "nearest"
- ] = "linear",
- method: Literal["single", "table"] = "single",
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().quantile(
- q=q,
- axis=axis,
- numeric_only=numeric_only,
- interpolation=interpolation,
- method=method,
- )
-
- @dataframe_not_implemented()
- def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200
- """
- Query the columns of a ``DataFrame`` with a boolean expression.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- self._update_var_dicts_in_kwargs(expr, kwargs)
- self._validate_eval_query(expr, **kwargs)
- inplace = validate_bool_kwarg(inplace, "inplace")
- new_query_compiler = self._query_compiler.query(expr, **kwargs)
- return self._create_or_update_from_compiler(new_query_compiler, inplace)
-
- def rename(
- self,
- mapper: Renamer | None = None,
- *,
- index: Renamer | None = None,
- columns: Renamer | None = None,
- axis: Axis | None = None,
- copy: bool | None = None,
- inplace: bool = False,
- level: Level | None = None,
- errors: IgnoreRaise = "ignore",
- ) -> DataFrame | None:
- """
- Alter axes labels.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- inplace = validate_bool_kwarg(inplace, "inplace")
- if mapper is None and index is None and columns is None:
- raise TypeError("must pass an index to rename")
-
- if index is not None or columns is not None:
- if axis is not None:
- raise TypeError(
- "Cannot specify both 'axis' and any of 'index' or 'columns'"
- )
- elif mapper is not None:
- raise TypeError(
- "Cannot specify both 'mapper' and any of 'index' or 'columns'"
- )
- else:
- # use the mapper argument
- if axis and self._get_axis_number(axis) == 1:
- columns = mapper
- else:
- index = mapper
-
- if copy is not None:
- WarningMessage.ignored_argument(
- operation="dataframe.rename",
- argument="copy",
- message="copy parameter has been ignored with Snowflake execution engine",
- )
-
- if isinstance(index, dict):
- index = Series(index)
-
- new_qc = self._query_compiler.rename(
- index_renamer=index, columns_renamer=columns, level=level, errors=errors
- )
- return self._create_or_update_from_compiler(
- new_query_compiler=new_qc, inplace=inplace
- )
-
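A usage sketch of the mapper/axis resolution described above, using native pandas rename; labels are hypothetical.

import pandas as native_pd

frame = native_pd.DataFrame({"a": [1]}, index=["r"])
assert list(frame.rename(str.upper, axis=1).columns) == ["A"]    # mapper + axis=1 renames columns
assert list(frame.rename(index={"r": "row"}).index) == ["row"]   # explicit index renamer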
- def reindex(
- self,
- labels=None,
- index=None,
- columns=None,
- axis=None,
- method=None,
- copy=None,
- level=None,
- fill_value=np.nan,
- limit=None,
- tolerance=None,
- ): # noqa: PR01, RT01, D200
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-
- axis = self._get_axis_number(axis)
- if axis == 0 and labels is not None:
- index = labels
- elif labels is not None:
- columns = labels
- return super().reindex(
- index=index,
- columns=columns,
- method=method,
- copy=copy,
- level=level,
- fill_value=fill_value,
- limit=limit,
- tolerance=tolerance,
- )
-
- @dataframe_not_implemented()
- def reindex_like(
- self,
- other,
- method=None,
- copy: bool | None = None,
- limit=None,
- tolerance=None,
- ) -> DataFrame: # pragma: no cover
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if copy is None:
- copy = True
- # docs say "Same as calling .reindex(index=other.index, columns=other.columns,...).":
- # https://pandas.pydata.org/pandas-docs/version/1.4/reference/api/pandas.DataFrame.reindex_like.html
- return self.reindex(
- index=other.index,
- columns=other.columns,
- method=method,
- copy=copy,
- limit=limit,
- tolerance=tolerance,
- )
-
- def replace(
- self,
- to_replace=None,
- value=no_default,
- inplace: bool = False,
- limit=None,
- regex: bool = False,
- method: str | NoDefault = no_default,
- ):
- """
- Replace values given in `to_replace` with `value`.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- inplace = validate_bool_kwarg(inplace, "inplace")
- new_query_compiler = self._query_compiler.replace(
- to_replace=to_replace,
- value=value,
- limit=limit,
- regex=regex,
- method=method,
- )
- return self._create_or_update_from_compiler(new_query_compiler, inplace)
-
- def rfloordiv(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `rfloordiv`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "rfloordiv",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- def radd(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get addition of ``DataFrame`` and `other`, element-wise (binary operator `radd`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "radd",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- def rmod(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `rmod`).
- """
- return self._binary_op(
- "rmod",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200
- return super().round(decimals, args=args, **kwargs)
-
- def rpow(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "rpow",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- def rsub(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `rsub`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "rsub",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- def rtruediv(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `rtruediv`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "rtruediv",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- rdiv = rtruediv
-
- def select_dtypes(
- self,
- include: ListLike | str | type | None = None,
- exclude: ListLike | str | type | None = None,
- ) -> DataFrame:
- """
- Return a subset of the ``DataFrame``'s columns based on the column dtypes.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
-        # This line defers argument validation to pandas, which will raise errors on our behalf in cases
-        # where `include` and `exclude` are both None, the same type is specified in both lists, or a string
- # dtype (as opposed to object) is specified.
- pandas.DataFrame().select_dtypes(include, exclude)
-
- if include and not is_list_like(include):
- include = [include]
- elif include is None:
- include = []
- if exclude and not is_list_like(exclude):
- exclude = [exclude]
- elif exclude is None:
- exclude = []
-
- sel = tuple(map(set, (include, exclude)))
-
- # The width of the np.int_/float_ alias differs between Windows and other platforms, so
- # we need to include a workaround.
- # https://github.com/numpy/numpy/issues/9464
- # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036
- def check_sized_number_infer_dtypes(dtype):
- if (isinstance(dtype, str) and dtype == "int") or (dtype is int):
- return [np.int32, np.int64]
- elif dtype == "float" or dtype is float:
- return [np.float64, np.float32]
- else:
- return [infer_dtype_from_object(dtype)]
-
- include, exclude = map(
- lambda x: set(
- itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x))
- ),
- sel,
- )
- # We need to index on column position rather than label in case of duplicates
- include_these = pandas.Series(not bool(include), index=range(len(self.columns)))
- exclude_these = pandas.Series(not bool(exclude), index=range(len(self.columns)))
-
- def is_dtype_instance_mapper(dtype):
- return functools.partial(issubclass, dtype.type)
-
- for i, dtype in enumerate(self.dtypes):
- if include:
- include_these[i] = any(map(is_dtype_instance_mapper(dtype), include))
- if exclude:
- exclude_these[i] = not any(
- map(is_dtype_instance_mapper(dtype), exclude)
- )
-
- dtype_indexer = include_these & exclude_these
- indicate = [i for i, should_keep in dtype_indexer.items() if should_keep]
- # We need to use iloc instead of drop in case of duplicate column names
- return self.iloc[:, indicate]
-
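The `check_sized_number_infer_dtypes` helper above widens the platform-dependent `int`/`float` aliases so that, for example, `include="int"` matches both 32-bit and 64-bit integer columns on every platform. The intended result matches stock pandas:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"i": [1, 2], "f": [1.5, 2.5], "s": ["x", "y"]})

# "int" is expanded to [np.int32, np.int64] above, so integer columns match
# regardless of the platform's default integer width.
print(df.select_dtypes(include="int").columns.tolist())        # ['i']
print(df.select_dtypes(exclude=[np.number]).columns.tolist())  # ['s']
```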
- def shift(
- self,
- periods: int | Sequence[int] = 1,
- freq=None,
- axis: Axis = 0,
- fill_value: Hashable = no_default,
- suffix: str | None = None,
- ) -> DataFrame:
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().shift(periods, freq, axis, fill_value, suffix)
-
- def set_index(
- self,
- keys: IndexLabel
- | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable],
- drop: bool = True,
- append: bool = False,
- inplace: bool = False,
- verify_integrity: bool = False,
- ) -> None | DataFrame:
- """
- Set the ``DataFrame`` index using existing columns.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- inplace = validate_bool_kwarg(inplace, "inplace")
- if not isinstance(keys, list):
- keys = [keys]
-
- # make sure key is either hashable, index, or series
- label_or_series = []
-
- missing = []
- columns = self.columns.tolist()
- for key in keys:
- raise_if_native_pandas_objects(key)
- if isinstance(key, pd.Series):
- label_or_series.append(key._query_compiler)
- elif isinstance(key, (np.ndarray, list, Iterator)):
- label_or_series.append(pd.Series(key)._query_compiler)
- elif isinstance(key, (pd.Index, pandas.MultiIndex)):
- label_or_series += [
- s._query_compiler for s in self._to_series_list(key)
- ]
- else:
- if not is_hashable(key):
- raise TypeError(
- f'The parameter "keys" may be a column key, one-dimensional array, or a list '
- f"containing only valid column keys and one-dimensional arrays. Received column "
- f"of type {type(key)}"
- )
- label_or_series.append(key)
- found = key in columns
- if columns.count(key) > 1:
- raise ValueError(f"The column label '{key}' is not unique")
- elif not found:
- missing.append(key)
-
- if missing:
- raise KeyError(f"None of {missing} are in the columns")
-
- new_query_compiler = self._query_compiler.set_index(
- label_or_series, drop=drop, append=append
- )
-
- # TODO: SNOW-782633 improve this code once duplicate is supported
-        # this needs to pull the entire index, which is inefficient
- if verify_integrity and not new_query_compiler.index.is_unique:
- duplicates = new_query_compiler.index[
- new_query_compiler.index.to_pandas().duplicated()
- ].unique()
- raise ValueError(f"Index has duplicate keys: {duplicates}")
-
- return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace)
-
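The key validation above accepts the same mix of inputs as pandas: hashable column labels, one-dimensional arrays, `Series`, and `Index` objects, with duplicate or missing labels rejected up front. A short sketch of the accepted forms and error paths in stock pandas semantics (the `verify_integrity` branch in the removed code additionally pulls the index to the client, as the TODO notes):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

df.set_index("a")              # a single column label
df.set_index(["a", [10, 20]])  # a label mixed with a one-dimensional array
# df.set_index("missing")     # KeyError: "None of ['missing'] are in the columns"
```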
- sparse = CachedAccessor("sparse", SparseFrameAccessor)
-
- def squeeze(self, axis: Axis | None = None):
- """
- Squeeze 1 dimensional axis objects into scalars.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- axis = self._get_axis_number(axis) if axis is not None else None
- len_columns = self._query_compiler.get_axis_len(1)
- if axis == 1 and len_columns == 1:
- return Series(query_compiler=self._query_compiler)
- if axis in [0, None]:
-            # get_axis_len(0) results in a SQL query to count the number of rows in the current
-            # dataframe. We should only compute len_index if axis is 0 or None.
- len_index = len(self)
- if axis is None and (len_columns == 1 or len_index == 1):
- return Series(query_compiler=self._query_compiler).squeeze()
- if axis == 0 and len_index == 1:
- return Series(query_compiler=self.T._query_compiler)
- return self.copy()
-
- def stack(
- self,
- level: int | str | list = -1,
- dropna: bool | NoDefault = no_default,
- sort: bool | NoDefault = no_default,
- future_stack: bool = False, # ignored
- ):
- """
- Stack the prescribed level(s) from columns to index.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if future_stack is not False:
- WarningMessage.ignored_argument( # pragma: no cover
- operation="DataFrame.stack",
- argument="future_stack",
- message="future_stack parameter has been ignored with Snowflake execution engine",
- )
- if dropna is NoDefault:
- dropna = True # pragma: no cover
- if sort is NoDefault:
- sort = True # pragma: no cover
-
- # This ensures that non-pandas MultiIndex objects are caught.
- is_multiindex = len(self.columns.names) > 1
- if not is_multiindex or (
- is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels
- ):
- return self._reduce_dimension(
- query_compiler=self._query_compiler.stack(level, dropna, sort)
- )
- else:
- return self.__constructor__(
- query_compiler=self._query_compiler.stack(level, dropna, sort)
- )
-
- def sub(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `sub`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "sub",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- subtract = sub
-
- @dataframe_not_implemented()
- def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200
- """
- Write a ``DataFrame`` to the binary Feather format.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(pandas.DataFrame.to_feather, path, **kwargs)
-
- @dataframe_not_implemented()
- def to_gbq(
- self,
- destination_table,
- project_id=None,
- chunksize=None,
- reauth=False,
- if_exists="fail",
- auth_local_webserver=True,
- table_schema=None,
- location=None,
- progress_bar=True,
- credentials=None,
- ): # pragma: no cover # noqa: PR01, RT01, D200
- """
- Write a ``DataFrame`` to a Google BigQuery table.
- """
-        # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.to_gbq,
- destination_table,
- project_id=project_id,
- chunksize=chunksize,
- reauth=reauth,
- if_exists=if_exists,
- auth_local_webserver=auth_local_webserver,
- table_schema=table_schema,
- location=location,
- progress_bar=progress_bar,
- credentials=credentials,
- )
-
- @dataframe_not_implemented()
- def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.to_orc,
- path=path,
- engine=engine,
- index=index,
- engine_kwargs=engine_kwargs,
- )
-
- @dataframe_not_implemented()
- def to_html(
- self,
- buf=None,
- columns=None,
- col_space=None,
- header=True,
- index=True,
- na_rep="NaN",
- formatters=None,
- float_format=None,
- sparsify=None,
- index_names=True,
- justify=None,
- max_rows=None,
- max_cols=None,
- show_dimensions=False,
- decimal=".",
- bold_rows=True,
- classes=None,
- escape=True,
- notebook=False,
- border=None,
- table_id=None,
- render_links=False,
- encoding=None,
- ): # noqa: PR01, RT01, D200
- """
- Render a ``DataFrame`` as an HTML table.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.to_html,
- buf=buf,
- columns=columns,
- col_space=col_space,
- header=header,
- index=index,
- na_rep=na_rep,
- formatters=formatters,
- float_format=float_format,
- sparsify=sparsify,
- index_names=index_names,
- justify=justify,
- max_rows=max_rows,
- max_cols=max_cols,
- show_dimensions=show_dimensions,
- decimal=decimal,
- bold_rows=bold_rows,
- classes=classes,
- escape=escape,
- notebook=notebook,
- border=border,
- table_id=table_id,
- render_links=render_links,
-            encoding=encoding,
- )
-
- @dataframe_not_implemented()
- def to_parquet(
- self,
- path=None,
- engine="auto",
- compression="snappy",
- index=None,
- partition_cols=None,
- storage_options: StorageOptions = None,
- **kwargs,
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- from snowflake.snowpark.modin.pandas.dispatching.factories.dispatcher import (
- FactoryDispatcher,
- )
-
- return FactoryDispatcher.to_parquet(
- self._query_compiler,
- path=path,
- engine=engine,
- compression=compression,
- index=index,
- partition_cols=partition_cols,
- storage_options=storage_options,
- **kwargs,
- )
-
- @dataframe_not_implemented()
- def to_period(
- self, freq=None, axis=0, copy=True
- ): # pragma: no cover # noqa: PR01, RT01, D200
- """
- Convert ``DataFrame`` from ``DatetimeIndex`` to ``PeriodIndex``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().to_period(freq=freq, axis=axis, copy=copy)
-
- @dataframe_not_implemented()
- def to_records(
- self, index=True, column_dtypes=None, index_dtypes=None
- ): # noqa: PR01, RT01, D200
- """
- Convert ``DataFrame`` to a NumPy record array.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.to_records,
- index=index,
- column_dtypes=column_dtypes,
- index_dtypes=index_dtypes,
- )
-
- @dataframe_not_implemented()
- def to_stata(
- self,
- path: FilePath | WriteBuffer[bytes],
- convert_dates: dict[Hashable, str] | None = None,
- write_index: bool = True,
- byteorder: str | None = None,
- time_stamp: datetime.datetime | None = None,
- data_label: str | None = None,
- variable_labels: dict[Hashable, str] | None = None,
- version: int | None = 114,
- convert_strl: Sequence[Hashable] | None = None,
- compression: CompressionOptions = "infer",
- storage_options: StorageOptions = None,
- *,
- value_labels: dict[Hashable, dict[float | int, str]] | None = None,
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.to_stata,
- path,
- convert_dates=convert_dates,
- write_index=write_index,
- byteorder=byteorder,
- time_stamp=time_stamp,
- data_label=data_label,
- variable_labels=variable_labels,
- version=version,
- convert_strl=convert_strl,
- compression=compression,
- storage_options=storage_options,
- value_labels=value_labels,
- )
-
- @dataframe_not_implemented()
- def to_xml(
- self,
- path_or_buffer=None,
- index=True,
- root_name="data",
- row_name="row",
- na_rep=None,
- attr_cols=None,
- elem_cols=None,
- namespaces=None,
- prefix=None,
- encoding="utf-8",
- xml_declaration=True,
- pretty_print=True,
- parser="lxml",
- stylesheet=None,
- compression="infer",
- storage_options=None,
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self.__constructor__(
- query_compiler=self._query_compiler.default_to_pandas(
- pandas.DataFrame.to_xml,
- path_or_buffer=path_or_buffer,
- index=index,
- root_name=root_name,
- row_name=row_name,
- na_rep=na_rep,
- attr_cols=attr_cols,
- elem_cols=elem_cols,
- namespaces=namespaces,
- prefix=prefix,
- encoding=encoding,
- xml_declaration=xml_declaration,
- pretty_print=pretty_print,
- parser=parser,
- stylesheet=stylesheet,
- compression=compression,
- storage_options=storage_options,
- )
- )
-
- def to_dict(
- self,
- orient: Literal[
- "dict", "list", "series", "split", "tight", "records", "index"
- ] = "dict",
- into: type[dict] = dict,
- ) -> dict | list[dict]:
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._to_pandas().to_dict(orient=orient, into=into)
-
- def to_timestamp(
- self, freq=None, how="start", axis=0, copy=True
- ): # noqa: PR01, RT01, D200
- """
- Cast to DatetimeIndex of timestamps, at *beginning* of period.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().to_timestamp(freq=freq, how=how, axis=axis, copy=copy)
-
- def truediv(
- self, other, axis="columns", level=None, fill_value=None
- ): # noqa: PR01, RT01, D200
- """
- Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `truediv`).
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op(
- "truediv",
- other,
- axis=axis,
- level=level,
- fill_value=fill_value,
- )
-
- div = divide = truediv
-
- def update(
- self, other, join="left", overwrite=True, filter_func=None, errors="ignore"
- ): # noqa: PR01, RT01, D200
- """
- Modify in place using non-NA values from another ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if not isinstance(other, DataFrame):
- other = self.__constructor__(other)
- query_compiler = self._query_compiler.df_update(
- other._query_compiler,
- join=join,
- overwrite=overwrite,
- filter_func=filter_func,
- errors=errors,
- )
- self._update_inplace(new_query_compiler=query_compiler)
-
- def diff(
- self,
- periods: int = 1,
- axis: Axis = 0,
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().diff(
- periods=periods,
- axis=axis,
- )
-
- def drop(
- self,
- labels: IndexLabel = None,
- axis: Axis = 0,
- index: IndexLabel = None,
- columns: IndexLabel = None,
- level: Level = None,
- inplace: bool = False,
- errors: IgnoreRaise = "raise",
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().drop(
- labels=labels,
- axis=axis,
- index=index,
- columns=columns,
- level=level,
- inplace=inplace,
- errors=errors,
- )
-
- def value_counts(
- self,
- subset: Sequence[Hashable] | None = None,
- normalize: bool = False,
- sort: bool = True,
- ascending: bool = False,
- dropna: bool = True,
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return Series(
- query_compiler=self._query_compiler.value_counts(
- subset=subset,
- normalize=normalize,
- sort=sort,
- ascending=ascending,
- dropna=dropna,
- ),
- name="proportion" if normalize else "count",
- )
-
- def mask(
- self,
- cond: DataFrame | Series | Callable | AnyArrayLike,
- other: DataFrame | Series | Callable | Scalar | None = np.nan,
- *,
- inplace: bool = False,
- axis: Axis | None = None,
- level: Level | None = None,
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if isinstance(other, Series) and axis is None:
- raise ValueError(
- "df.mask requires an axis parameter (0 or 1) when given a Series"
- )
-
- return super().mask(
- cond,
- other=other,
- inplace=inplace,
- axis=axis,
- level=level,
- )
-
- def where(
- self,
- cond: DataFrame | Series | Callable | AnyArrayLike,
- other: DataFrame | Series | Callable | Scalar | None = np.nan,
- *,
- inplace: bool = False,
- axis: Axis | None = None,
- level: Level | None = None,
- ):
- """
- Replace values where the condition is False.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if isinstance(other, Series) and axis is None:
- raise ValueError(
- "df.where requires an axis parameter (0 or 1) when given a Series"
- )
-
- return super().where(
- cond,
- other=other,
- inplace=inplace,
- axis=axis,
- level=level,
- )
-
- @dataframe_not_implemented()
- def xs(self, key, axis=0, level=None, drop_level=True): # noqa: PR01, RT01, D200
- """
- Return cross-section from the ``DataFrame``.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._default_to_pandas(
- pandas.DataFrame.xs, key, axis=axis, level=level, drop_level=drop_level
- )
-
- def set_axis(
- self,
- labels: IndexLabel,
- *,
- axis: Axis = 0,
- copy: bool | NoDefault = no_default, # ignored
- ):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if not is_scalar(axis):
- raise TypeError(f"{type(axis).__name__} is not a valid type for axis.")
- return super().set_axis(
- labels=labels,
-            # 'columns', 'rows', 'index', 0, and 1 are the only valid axis values for df.
- axis=pandas.DataFrame._get_axis_name(axis),
- copy=copy,
- )
-
- def __getattr__(self, key):
- """
- Return item identified by `key`.
-
- Parameters
- ----------
- key : hashable
- Key to get.
-
- Returns
- -------
- Any
-
- Notes
- -----
- First try to use `__getattribute__` method. If it fails
- try to get `key` from ``DataFrame`` fields.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- try:
- return object.__getattribute__(self, key)
- except AttributeError as err:
- if key not in _ATTRS_NO_LOOKUP and key in self.columns:
- return self[key]
- raise err
-
- def __setattr__(self, key, value):
- """
- Set attribute `value` identified by `key`.
-
- Parameters
- ----------
- key : hashable
- Key to set.
- value : Any
- Value to set.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- # While we let users assign to a column labeled "x" with "df.x" , there
- # are some attributes that we should assume are NOT column names and
- # therefore should follow the default Python object assignment
- # behavior. These are:
- # - anything in self.__dict__. This includes any attributes that the
- # user has added to the dataframe with, e.g., `df.c = 3`, and
- # any attribute that Modin has added to the frame, e.g.
- # `_query_compiler` and `_siblings`
- # - `_query_compiler`, which Modin initializes before it appears in
- # __dict__
- # - `_siblings`, which Modin initializes before it appears in __dict__
- # - `_cache`, which pandas.cache_readonly uses to cache properties
- # before it appears in __dict__.
- if key in ("_query_compiler", "_siblings", "_cache") or key in self.__dict__:
- pass
- elif key in self and key not in dir(self):
- self.__setitem__(key, value)
- # Note: return immediately so we don't keep this `key` as dataframe state.
- # `__getattr__` will return the columns not present in `dir(self)`, so we do not need
- # to manually track this state in the `dir`.
- return
- elif is_list_like(value) and key not in ["index", "columns"]:
- WarningMessage.single_warning(
- SET_DATAFRAME_ATTRIBUTE_WARNING
- ) # pragma: no cover
- object.__setattr__(self, key, value)
-
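The assignment rules spelled out in the comments above can be summarized with a short sketch. This is illustrative only and assumes `pd` is Snowpark pandas (`import modin.pandas as pd` with the Snowpark plugin imported); the warning mentioned is the one referenced by `SET_DATAFRAME_ATTRIBUTE_WARNING`:

```python
df = pd.DataFrame({"a": [1, 2]})

df.c = 3        # "c" is not a column: stored as a plain Python attribute, no column is created
df.c2 = [5, 6]  # list-like assigned to a new attribute: emits the warning, still only an attribute
print("c" in df.columns, "c2" in df.columns)  # False False -- no columns were created
```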
- def __setitem__(self, key: Any, value: Any):
- """
- Set attribute `value` identified by `key`.
-
- Args:
- key: Key to set
- value: Value to set
-
- Note:
-            In the case where value is any list-like or array, pandas checks the array length against the number of rows
-            of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw
-            a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead, if
-            the array is longer than the number of rows, we ignore the additional values. If the array is shorter, we use
- enlargement filling with the last value in the array.
-
- Returns:
- None
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- key = apply_if_callable(key, self)
- if isinstance(key, DataFrame) or (
- isinstance(key, np.ndarray) and len(key.shape) == 2
- ):
- # This case uses mask's codepath to perform the set, but
- # we need to duplicate the code here since we are passing
- # an additional kwarg `cond_fillna_with_true` to the QC here.
- # We need this additional kwarg, since if df.shape
- # and key.shape do not align (i.e. df has more rows),
- # mask's codepath would mask the additional rows in df
- # while for setitem, we need to keep the original values.
- if not isinstance(key, DataFrame):
- if key.dtype != bool:
- raise TypeError(
- "Must pass DataFrame or 2-d ndarray with boolean values only"
- )
- key = DataFrame(key)
- key._query_compiler._shape_hint = "array"
-
- if value is not None:
- value = apply_if_callable(value, self)
-
- if isinstance(value, np.ndarray):
- value = DataFrame(value)
- value._query_compiler._shape_hint = "array"
- elif isinstance(value, pd.Series):
- # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this
- # error instead, since it is more descriptive.
- raise ValueError(
- "setitem with a 2D key does not support Series values."
- )
-
- if isinstance(value, BasePandasDataset):
- value = value._query_compiler
-
- query_compiler = self._query_compiler.mask(
- cond=key._query_compiler,
- other=value,
- axis=None,
- level=None,
- cond_fillna_with_true=True,
- )
-
- return self._create_or_update_from_compiler(query_compiler, inplace=True)
-
- # Error Checking:
- if (isinstance(key, pd.Series) or is_list_like(key)) and (
- isinstance(value, range)
- ):
- raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE)
- elif isinstance(value, slice):
-            # Here, the whole slice is assigned as a scalar value, i.e., a single position in the frame would receive the slice object itself as its value.
- raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE)
-
-        # Note: when key is a boolean indexer or a slice, the key is a row key; otherwise, the key is always a
-        # column key.
- index, columns = slice(None), key
- index_is_bool_indexer = False
- if isinstance(key, slice):
- if is_integer(key.start) and is_integer(key.stop):
-                # when the slice is an integer slice, e.g., df[1:2] = val, the behavior is the same as
-                # df.iloc[1:2, :] = val
- self.iloc[key] = value
- return
- index, columns = key, slice(None)
- elif isinstance(key, pd.Series):
- if is_bool_dtype(key.dtype):
- index, columns = key, slice(None)
- index_is_bool_indexer = True
- elif is_bool_indexer(key):
- index, columns = pd.Series(key), slice(None)
- index_is_bool_indexer = True
-
-        # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case:
-        # we have to explicitly set matching_item_columns_by_label to False for setitem.
- index = index._query_compiler if isinstance(index, BasePandasDataset) else index
- columns = (
- columns._query_compiler
- if isinstance(columns, BasePandasDataset)
- else columns
- )
- from .indexing import is_2d_array
-
- matching_item_rows_by_label = not is_2d_array(value)
- if is_2d_array(value):
- value = DataFrame(value)
- item = value._query_compiler if isinstance(value, BasePandasDataset) else value
- new_qc = self._query_compiler.set_2d_labels(
- index,
- columns,
- item,
- # setitem always matches item by position
- matching_item_columns_by_label=False,
- matching_item_rows_by_label=matching_item_rows_by_label,
- index_is_bool_indexer=index_is_bool_indexer,
- # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling
- # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the
- # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have
- # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns
- # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B",
- # "X", "X".
- deduplicate_columns=True,
- )
- return self._update_inplace(new_query_compiler=new_qc)
-
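The Note in the docstring above is the main divergence from pandas: the length of a list-like value is not validated against the frame, because counting rows would force an eager query. A sketch of the intended semantics, assuming `pd` is Snowpark pandas with an active session (the expected results restate the description above rather than a recorded run):

```python
df = pd.DataFrame({"a": [1, 2, 3]})

df["b"] = [10, 20, 30, 40]  # longer than the frame: the extra 40 is silently ignored
df["c"] = [10, 20]          # shorter: the last value is reused, so "c" becomes [10, 20, 20]

# Duplicate column keys are deduplicated, unlike df.loc[:, ["a", "a"]] = ...;
# "a" receives the values of the *second* column of the right-hand side.
df[["a", "a"]] = df[["b", "c"]]
```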
- def abs(self):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().abs()
-
- def __and__(self, other):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("__and__", other, axis=1)
-
- def __rand__(self, other):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("__rand__", other, axis=1)
-
- def __or__(self, other):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("__or__", other, axis=1)
-
- def __ror__(self, other):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._binary_op("__ror__", other, axis=1)
-
- def __neg__(self):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().__neg__()
-
- def __iter__(self):
- """
- Iterate over info axis.
-
- Returns
- -------
- iterable
- Iterator of the columns names.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return iter(self.columns)
-
- def __contains__(self, key):
- """
-        Check if `key` is in ``DataFrame.columns``.
-
- Parameters
- ----------
- key : hashable
- Key to check the presence in the columns.
-
- Returns
- -------
- bool
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self.columns.__contains__(key)
-
- def __round__(self, decimals=0):
- """
- Round each value in a ``DataFrame`` to the given number of decimals.
-
- Parameters
- ----------
- decimals : int, default: 0
- Number of decimal places to round to.
-
- Returns
- -------
- DataFrame
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return super().round(decimals)
-
- @dataframe_not_implemented()
- def __delitem__(self, key):
- """
- Delete item identified by `key` label.
-
- Parameters
- ----------
- key : hashable
- Key to delete.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if key not in self:
- raise KeyError(key)
- self._update_inplace(new_query_compiler=self._query_compiler.delitem(key))
-
- __add__ = add
- __iadd__ = add # pragma: no cover
- __radd__ = radd
- __mul__ = mul
- __imul__ = mul # pragma: no cover
- __rmul__ = rmul
- __pow__ = pow
- __ipow__ = pow # pragma: no cover
- __rpow__ = rpow
- __sub__ = sub
- __isub__ = sub # pragma: no cover
- __rsub__ = rsub
- __floordiv__ = floordiv
- __ifloordiv__ = floordiv # pragma: no cover
- __rfloordiv__ = rfloordiv
- __truediv__ = truediv
- __itruediv__ = truediv # pragma: no cover
- __rtruediv__ = rtruediv
- __mod__ = mod
- __imod__ = mod # pragma: no cover
- __rmod__ = rmod
- __rdiv__ = rdiv
-
- def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
- """
- Get a Modin DataFrame that implements the dataframe exchange protocol.
-
- See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html.
-
- Parameters
- ----------
- nan_as_null : bool, default: False
- A keyword intended for the consumer to tell the producer
- to overwrite null values in the data with ``NaN`` (or ``NaT``).
- This currently has no effect; once support for nullable extension
- dtypes is added, this value should be propagated to columns.
- allow_copy : bool, default: True
- A keyword that defines whether or not the library is allowed
- to make a copy of the data. For example, copying data would be necessary
- if a library supports strided buffers, given that this protocol
- specifies contiguous buffers. Currently, if the flag is set to ``False``
- and a copy is needed, a ``RuntimeError`` will be raised.
-
- Returns
- -------
- ProtocolDataframe
- A dataframe object following the dataframe protocol specification.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- ErrorMessage.not_implemented(
- "Snowpark pandas does not support the DataFrame interchange "
- + "protocol method `__dataframe__`. To use Snowpark pandas "
- + "DataFrames with third-party libraries that try to call the "
- + "`__dataframe__` method, please convert this Snowpark pandas "
- + "DataFrame to pandas with `to_pandas()`."
- )
-
- return self._query_compiler.to_dataframe(
- nan_as_null=nan_as_null, allow_copy=allow_copy
- )
-
- @dataframe_not_implemented()
- @property
- def attrs(self): # noqa: RT01, D200
- """
- Return dictionary of global attributes of this dataset.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- def attrs(df):
- return df.attrs
-
- return self._default_to_pandas(attrs)
-
- @dataframe_not_implemented()
- @property
- def style(self): # noqa: RT01, D200
- """
- Return a Styler object.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- def style(df):
- """Define __name__ attr because properties do not have it."""
- return df.style
-
- return self._default_to_pandas(style)
-
- def isin(
- self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike]
- ) -> DataFrame:
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if isinstance(values, dict):
- return super().isin(values)
- elif isinstance(values, Series):
-            # Note: pandas performs an explicit is_unique check here; it is deactivated for performance reasons.
- # if not values.index.is_unique:
- # raise ValueError("cannot compute isin with a duplicate axis.")
- return self.__constructor__(
- query_compiler=self._query_compiler.isin(values._query_compiler)
- )
- elif isinstance(values, DataFrame):
-            # Note: pandas performs an explicit is_unique check here; it is deactivated for performance reasons.
- # if not (values.columns.is_unique and values.index.is_unique):
- # raise ValueError("cannot compute isin with a duplicate axis.")
- return self.__constructor__(
- query_compiler=self._query_compiler.isin(values._query_compiler)
- )
- else:
- if not is_list_like(values):
- # throw pandas compatible error
- raise TypeError(
- "only list-like or dict-like objects are allowed "
- f"to be passed to {self.__class__.__name__}.isin(), "
- f"you passed a '{type(values).__name__}'"
- )
- return super().isin(values)
-
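`isin` keeps pandas semantics for each accepted input type; the `Series` and `DataFrame` branches merely skip pandas' duplicate-axis check for performance, as noted above. For reference, the expected behavior in stock pandas:

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

df.isin([1, 3])                            # list-like: every column is checked against the same values
df.isin({"a": [1], "b": [999]})            # dict: each column is checked only against its own values
df.isin(pd.Series([1, 3], index=[0, 1]))   # Series: compared element-wise after aligning on the index
# df.isin("a")                             # TypeError: only list-like or dict-like objects are allowed
```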
- def _create_or_update_from_compiler(self, new_query_compiler, inplace=False):
- """
- Return or update a ``DataFrame`` with given `new_query_compiler`.
-
- Parameters
- ----------
- new_query_compiler : PandasQueryCompiler
- QueryCompiler to use to manage the data.
- inplace : bool, default: False
- Whether or not to perform update or creation inplace.
-
- Returns
- -------
- DataFrame or None
- None if update was done, ``DataFrame`` otherwise.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- assert (
- isinstance(new_query_compiler, type(self._query_compiler))
- or type(new_query_compiler) in self._query_compiler.__class__.__bases__
- ), f"Invalid Query Compiler object: {type(new_query_compiler)}"
- if not inplace:
- return self.__constructor__(query_compiler=new_query_compiler)
- else:
- self._update_inplace(new_query_compiler=new_query_compiler)
-
- def _get_numeric_data(self, axis: int):
- """
- Grab only numeric data from ``DataFrame``.
-
- Parameters
- ----------
- axis : {0, 1}
- Axis to inspect on having numeric types only.
-
- Returns
- -------
- DataFrame
- ``DataFrame`` with numeric data.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- # pandas ignores `numeric_only` if `axis` is 1, but we do have to drop
- # non-numeric columns if `axis` is 0.
- if axis != 0:
- return self
- return self.drop(
- columns=[
- i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i])
- ]
- )
-
- def _validate_dtypes(self, numeric_only=False):
- """
- Check that all the dtypes are the same.
-
- Parameters
- ----------
- numeric_only : bool, default: False
- Whether or not to allow only numeric data.
- If True and non-numeric data is found, exception
- will be raised.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- dtype = self.dtypes[0]
- for t in self.dtypes:
- if numeric_only and not is_numeric_dtype(t):
- raise TypeError(f"{t} is not a numeric data type")
- elif not numeric_only and t != dtype:
- raise TypeError(f"Cannot compare type '{t}' with type '{dtype}'")
-
- def _validate_dtypes_sum_prod_mean(self, axis, numeric_only, ignore_axis=False):
- """
- Validate data dtype for `sum`, `prod` and `mean` methods.
-
- Parameters
- ----------
- axis : {0, 1}
- Axis to validate over.
- numeric_only : bool
- Whether or not to allow only numeric data.
- If True and non-numeric data is found, exception
- will be raised.
- ignore_axis : bool, default: False
- Whether or not to ignore `axis` parameter.
-
- Returns
- -------
- DataFrame
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- # We cannot add datetime types, so if we are summing a column with
- # dtype datetime64 and cannot ignore non-numeric types, we must throw a
- # TypeError.
- if (
- not axis
- and numeric_only is False
- and any(dtype == np.dtype("datetime64[ns]") for dtype in self.dtypes)
- ):
- raise TypeError("Cannot add Timestamp Types")
-
- # If our DataFrame has both numeric and non-numeric dtypes then
- # operations between these types do not make sense and we must raise a
- # TypeError. The exception to this rule is when there are datetime and
- # timedelta objects, in which case we proceed with the comparison
- # without ignoring any non-numeric types. We must check explicitly if
- # numeric_only is False because if it is None, it will default to True
- # if the operation fails with mixed dtypes.
- if (
- (axis or ignore_axis)
- and numeric_only is False
- and np.unique([is_numeric_dtype(dtype) for dtype in self.dtypes]).size == 2
- ):
- # check if there are columns with dtypes datetime or timedelta
- if all(
- dtype != np.dtype("datetime64[ns]")
- and dtype != np.dtype("timedelta64[ns]")
- for dtype in self.dtypes
- ):
- raise TypeError("Cannot operate on Numeric and Non-Numeric Types")
-
- return self._get_numeric_data(axis) if numeric_only else self
-
- def _to_pandas(
- self,
- *,
- statement_params: dict[str, str] | None = None,
- **kwargs: Any,
- ) -> pandas.DataFrame:
- """
- Convert Snowpark pandas DataFrame to pandas DataFrame
-
- Args:
- statement_params: Dictionary of statement level parameters to be set while executing this action.
-
- Returns:
- pandas DataFrame
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._query_compiler.to_pandas(
- statement_params=statement_params, **kwargs
- )
-
- def _validate_eval_query(self, expr, **kwargs):
- """
- Validate the arguments of ``eval`` and ``query`` functions.
-
- Parameters
- ----------
- expr : str
- The expression to evaluate. This string cannot contain any
- Python statements, only Python expressions.
- **kwargs : dict
- Optional arguments of ``eval`` and ``query`` functions.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- if isinstance(expr, str) and expr == "":
- raise ValueError("expr cannot be an empty string")
-
- if isinstance(expr, str) and "not" in expr:
- if "parser" in kwargs and kwargs["parser"] == "python":
- ErrorMessage.not_implemented( # pragma: no cover
- "Snowpark pandas does not yet support 'not' in the "
- + "expression for the methods `DataFrame.eval` or "
- + "`DataFrame.query`"
- )
-
- def _reduce_dimension(self, query_compiler):
- """
- Reduce the dimension of data from the `query_compiler`.
-
- Parameters
- ----------
- query_compiler : BaseQueryCompiler
- Query compiler to retrieve the data.
-
- Returns
- -------
- Series
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return Series(query_compiler=query_compiler)
-
- def _set_axis_name(self, name, axis=0, inplace=False):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- axis = self._get_axis_number(axis)
- renamed = self if inplace else self.copy()
- if axis == 0:
- renamed.index = renamed.index.set_names(name)
- else:
- renamed.columns = renamed.columns.set_names(name)
- if not inplace:
- return renamed
-
- def _to_datetime(self, **kwargs):
- """
- Convert `self` to datetime.
-
- Parameters
- ----------
- **kwargs : dict
- Optional arguments to use during query compiler's
- `to_datetime` invocation.
-
- Returns
- -------
- Series of datetime64 dtype
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return self._reduce_dimension(
- query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs)
- )
-
-    # Persistence support methods - BEGIN
- @classmethod
- def _inflate_light(cls, query_compiler):
- """
-        Re-creates the object from a previously-serialized lightweight representation.
-
- The method is used for faster but not disk-storable persistence.
-
- Parameters
- ----------
- query_compiler : BaseQueryCompiler
- Query compiler to use for object re-creation.
-
- Returns
- -------
- DataFrame
- New ``DataFrame`` based on the `query_compiler`.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return cls(query_compiler=query_compiler)
-
- @classmethod
- def _inflate_full(cls, pandas_df):
- """
-        Re-creates the object from a previously-serialized disk-storable representation.
-
- Parameters
- ----------
- pandas_df : pandas.DataFrame
- Data to use for object re-creation.
-
- Returns
- -------
- DataFrame
- New ``DataFrame`` based on the `pandas_df`.
- """
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- return cls(data=from_pandas(pandas_df))
-
- @dataframe_not_implemented()
- def __reduce__(self):
- # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
- self._query_compiler.finalize()
- # if PersistentPickle.get():
- # return self._inflate_full, (self._to_pandas(),)
- return self._inflate_light, (self._query_compiler,)
-
-    # Persistence support methods - END
diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py
index d5d158373de..5024d0618ac 100644
--- a/src/snowflake/snowpark/modin/pandas/general.py
+++ b/src/snowflake/snowpark/modin/pandas/general.py
@@ -31,7 +31,7 @@
import numpy as np
import pandas
import pandas.core.common as common
-from modin.pandas import Series
+from modin.pandas import DataFrame, Series
from modin.pandas.base import BasePandasDataset
from pandas import IntervalIndex, NaT, Timedelta, Timestamp
from pandas._libs import NaTType, lib
@@ -65,7 +65,6 @@
# add this line to make doctests runnable
from snowflake.snowpark.modin import pandas as pd # noqa: F401
-from snowflake.snowpark.modin.pandas.dataframe import DataFrame
from snowflake.snowpark.modin.pandas.utils import (
is_scalar,
raise_if_native_pandas_objects,
@@ -92,10 +91,9 @@
# linking to `snowflake.snowpark.DataFrame`, we need to explicitly
# qualify return types in this file with `modin.pandas.DataFrame`.
# SNOW-1233342: investigate how to fix these links without using absolute paths
+ import modin
from modin.core.storage_formats import BaseQueryCompiler # pragma: no cover
- import snowflake # pragma: no cover
-
_logger = getLogger(__name__)
VALID_DATE_TYPE = Union[
@@ -137,8 +135,8 @@ def notna(obj): # noqa: PR01, RT01, D200
@snowpark_pandas_telemetry_standalone_function_decorator
def merge(
- left: snowflake.snowpark.modin.pandas.DataFrame | Series,
- right: snowflake.snowpark.modin.pandas.DataFrame | Series,
+ left: modin.pandas.DataFrame | Series,
+ right: modin.pandas.DataFrame | Series,
how: str | None = "inner",
on: IndexLabel | None = None,
left_on: None
@@ -414,7 +412,7 @@ def merge_asof(
tolerance: int | Timedelta | None = None,
allow_exact_matches: bool = True,
direction: str = "backward",
-) -> snowflake.snowpark.modin.pandas.DataFrame:
+) -> modin.pandas.DataFrame:
"""
Perform a merge by key distance.
@@ -1047,7 +1045,7 @@ def unique(values) -> np.ndarray:
>>> pd.unique([pd.Timestamp('2016-01-01', tz='US/Eastern')
... for _ in range(3)])
- array([Timestamp('2015-12-31 21:00:00-0800', tz='America/Los_Angeles')],
+ array([Timestamp('2016-01-01 00:00:00-0500', tz='UTC-05:00')],
dtype=object)
>>> pd.unique([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")])
@@ -1105,8 +1103,8 @@ def value_counts(
@snowpark_pandas_telemetry_standalone_function_decorator
def concat(
objs: (
- Iterable[snowflake.snowpark.modin.pandas.DataFrame | Series]
- | Mapping[Hashable, snowflake.snowpark.modin.pandas.DataFrame | Series]
+ Iterable[modin.pandas.DataFrame | Series]
+ | Mapping[Hashable, modin.pandas.DataFrame | Series]
),
axis: Axis = 0,
join: str = "outer",
@@ -1117,7 +1115,7 @@ def concat(
verify_integrity: bool = False,
sort: bool = False,
copy: bool = True,
-) -> snowflake.snowpark.modin.pandas.DataFrame | Series:
+) -> modin.pandas.DataFrame | Series:
"""
Concatenate pandas objects along a particular axis.
@@ -1490,7 +1488,7 @@ def concat(
def to_datetime(
arg: DatetimeScalarOrArrayConvertible
| DictConvertible
- | snowflake.snowpark.modin.pandas.DataFrame
+ | modin.pandas.DataFrame
| Series,
errors: DateTimeErrorChoices = "raise",
dayfirst: bool = False,
@@ -1750,35 +1748,35 @@ def to_datetime(
DatetimeIndex(['2018-10-26 12:00:00', '2018-10-26 13:00:15'], dtype='datetime64[ns]', freq=None)
>>> pd.to_datetime(['2018-10-26 12:00:00 -0500', '2018-10-26 13:00:00 -0500'])
- DatetimeIndex(['2018-10-26 10:00:00-07:00', '2018-10-26 11:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
+ DatetimeIndex(['2018-10-26 12:00:00-05:00', '2018-10-26 13:00:00-05:00'], dtype='datetime64[ns, UTC-05:00]', freq=None)
     - Use the right format to convert to a timezone-aware type (Note that when calling the Snowpark
       pandas API to_pandas(), the timezone-aware output will always be converted to the session timezone):
>>> pd.to_datetime(['2018-10-26 12:00:00 -0500', '2018-10-26 13:00:00 -0500'], format="%Y-%m-%d %H:%M:%S %z")
- DatetimeIndex(['2018-10-26 10:00:00-07:00', '2018-10-26 11:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
+ DatetimeIndex(['2018-10-26 12:00:00-05:00', '2018-10-26 13:00:00-05:00'], dtype='datetime64[ns, UTC-05:00]', freq=None)
- Timezone-aware inputs *with mixed time offsets* (for example
issued from a timezone with daylight savings, such as Europe/Paris):
>>> pd.to_datetime(['2020-10-25 02:00:00 +0200', '2020-10-25 04:00:00 +0100'])
- DatetimeIndex(['2020-10-24 17:00:00-07:00', '2020-10-24 20:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
+ DatetimeIndex([2020-10-25 02:00:00+02:00, 2020-10-25 04:00:00+01:00], dtype='object', freq=None)
>>> pd.to_datetime(['2020-10-25 02:00:00 +0200', '2020-10-25 04:00:00 +0100'], format="%Y-%m-%d %H:%M:%S %z")
- DatetimeIndex(['2020-10-24 17:00:00-07:00', '2020-10-24 20:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
+ DatetimeIndex([2020-10-25 02:00:00+02:00, 2020-10-25 04:00:00+01:00], dtype='object', freq=None)
     Setting ``utc=True`` makes sure the output is always converted to timezone-aware values:
- Timezone-naive inputs are *localized* based on the session timezone
>>> pd.to_datetime(['2018-10-26 12:00', '2018-10-26 13:00'], utc=True)
- DatetimeIndex(['2018-10-26 05:00:00-07:00', '2018-10-26 06:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
+ DatetimeIndex(['2018-10-26 12:00:00+00:00', '2018-10-26 13:00:00+00:00'], dtype='datetime64[ns, UTC]', freq=None)
- Timezone-aware inputs are *converted* to session timezone
>>> pd.to_datetime(['2018-10-26 12:00:00 -0530', '2018-10-26 12:00:00 -0500'],
... utc=True)
- DatetimeIndex(['2018-10-26 10:30:00-07:00', '2018-10-26 10:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
+ DatetimeIndex(['2018-10-26 17:30:00+00:00', '2018-10-26 17:00:00+00:00'], dtype='datetime64[ns, UTC]', freq=None)
"""
# TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py
raise_if_native_pandas_objects(arg)
diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py
index c672f04da63..5da10d9b7a6 100644
--- a/src/snowflake/snowpark/modin/pandas/indexing.py
+++ b/src/snowflake/snowpark/modin/pandas/indexing.py
@@ -45,6 +45,7 @@
import pandas
from modin.pandas import Series
from modin.pandas.base import BasePandasDataset
+from modin.pandas.dataframe import DataFrame
from pandas._libs.tslibs import Resolution, parsing
from pandas._typing import AnyArrayLike, Scalar
from pandas.api.types import is_bool, is_list_like
@@ -61,7 +62,6 @@
import snowflake.snowpark.modin.pandas as pd
import snowflake.snowpark.modin.pandas.utils as frontend_utils
-from snowflake.snowpark.modin.pandas.dataframe import DataFrame
from snowflake.snowpark.modin.pandas.utils import is_scalar
from snowflake.snowpark.modin.plugin._internal.indexing_utils import (
MULTIPLE_ELLIPSIS_INDEXING_ERROR_MESSAGE,
diff --git a/src/snowflake/snowpark/modin/pandas/io.py b/src/snowflake/snowpark/modin/pandas/io.py
index 25959212a18..b92e8ee3582 100644
--- a/src/snowflake/snowpark/modin/pandas/io.py
+++ b/src/snowflake/snowpark/modin/pandas/io.py
@@ -92,7 +92,7 @@
# below logic is to handle circular imports without errors
if TYPE_CHECKING: # pragma: no cover
- from .dataframe import DataFrame
+ from modin.pandas.dataframe import DataFrame
# TODO: SNOW-1265551: add inherit_docstrings decorators once docstring overrides are available
@@ -106,7 +106,7 @@ class ModinObjects:
def DataFrame(cls):
"""Get ``modin.pandas.DataFrame`` class."""
if cls._dataframe is None:
- from .dataframe import DataFrame
+ from modin.pandas.dataframe import DataFrame
cls._dataframe = DataFrame
return cls._dataframe
diff --git a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py
index 3529355b81b..ee782f3cdf3 100644
--- a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py
+++ b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py
@@ -5,10 +5,9 @@
from collections.abc import Iterator
from typing import Any, Callable
+import modin.pandas.dataframe as DataFrame
import pandas
-import snowflake.snowpark.modin.pandas.dataframe as DataFrame
-
PARTITION_SIZE = 4096
diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py
index 3986e3d52a9..a48f16992d4 100644
--- a/src/snowflake/snowpark/modin/pandas/utils.py
+++ b/src/snowflake/snowpark/modin/pandas/utils.py
@@ -78,7 +78,7 @@ def from_non_pandas(df, index, columns, dtype):
new_qc = FactoryDispatcher.from_non_pandas(df, index, columns, dtype)
if new_qc is not None:
- from snowflake.snowpark.modin.pandas import DataFrame
+ from modin.pandas import DataFrame
return DataFrame(query_compiler=new_qc)
return new_qc
@@ -99,7 +99,7 @@ def from_pandas(df):
A new Modin DataFrame object.
"""
# from modin.core.execution.dispatching.factories.dispatcher import FactoryDispatcher
- from snowflake.snowpark.modin.pandas import DataFrame
+ from modin.pandas import DataFrame
return DataFrame(query_compiler=FactoryDispatcher.from_pandas(df))
@@ -118,10 +118,11 @@ def from_arrow(at):
DataFrame
A new Modin DataFrame object.
"""
+ from modin.pandas import DataFrame
+
from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import (
FactoryDispatcher,
)
- from snowflake.snowpark.modin.pandas import DataFrame
return DataFrame(query_compiler=FactoryDispatcher.from_arrow(at))
@@ -142,10 +143,11 @@ def from_dataframe(df):
DataFrame
A new Modin DataFrame object.
"""
+ from modin.pandas import DataFrame
+
from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import (
FactoryDispatcher,
)
- from snowflake.snowpark.modin.pandas import DataFrame
return DataFrame(query_compiler=FactoryDispatcher.from_dataframe(df))
@@ -226,7 +228,7 @@ def from_modin_frame_to_mi(df, sortorder=None, names=None):
pandas.MultiIndex
The pandas.MultiIndex representation of the given DataFrame.
"""
- from snowflake.snowpark.modin.pandas import DataFrame
+ from modin.pandas import DataFrame
if isinstance(df, DataFrame):
df = df._to_pandas()
diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py
index d3ac525572a..eceb9ca7d7f 100644
--- a/src/snowflake/snowpark/modin/plugin/__init__.py
+++ b/src/snowflake/snowpark/modin/plugin/__init__.py
@@ -69,6 +69,7 @@
inherit_modules = [
(docstrings.base.BasePandasDataset, modin.pandas.base.BasePandasDataset),
+ (docstrings.dataframe.DataFrame, modin.pandas.dataframe.DataFrame),
(docstrings.series.Series, modin.pandas.series.Series),
(docstrings.series_utils.StringMethods, modin.pandas.series_utils.StringMethods),
(
@@ -90,17 +91,3 @@
snowflake.snowpark._internal.utils.should_warn_dynamic_pivot_is_in_private_preview = (
False
)
-
-
-# TODO: SNOW-1504302: Modin upgrade - use Snowpark pandas DataFrame for isocalendar
-# OSS Modin's DatetimeProperties frontend class wraps the returned query compiler with `modin.pandas.DataFrame`.
-# Since we currently replace `pd.DataFrame` with our own Snowpark pandas DataFrame object, this causes errors
-# since OSS Modin explicitly imports its own DataFrame class here. This override can be removed once the frontend
-# DataFrame class is removed from our codebase.
-def isocalendar(self): # type: ignore
- from snowflake.snowpark.modin.pandas import DataFrame
-
- return DataFrame(query_compiler=self._query_compiler.dt_isocalendar())
-
-
-modin.pandas.series_utils.DatetimeProperties.isocalendar = isocalendar
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
index 01ccad8f430..0005df924db 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
@@ -9,7 +9,7 @@
from collections.abc import Hashable, Iterable
from functools import partial
from inspect import getmembers
-from types import BuiltinFunctionType
+from types import BuiltinFunctionType, MappingProxyType
from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Union
import numpy as np
@@ -56,6 +56,7 @@
stddev,
stddev_pop,
sum as sum_,
+ trunc,
var_pop,
variance,
when,
@@ -65,6 +66,9 @@
OrderedDataFrame,
OrderingColumn,
)
+from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
+ TimedeltaType,
+)
from snowflake.snowpark.modin.plugin._internal.utils import (
from_pandas_label,
pandas_lit,
@@ -85,7 +89,7 @@
}
-def array_agg_keepna(
+def _array_agg_keepna(
column_to_aggregate: ColumnOrName, ordering_columns: Iterable[OrderingColumn]
) -> Column:
"""
@@ -239,62 +243,63 @@ def _columns_coalescing_idxmax_idxmin_helper(
)
-# Map between the pandas input aggregation function (str or numpy function) and
-# the corresponding snowflake builtin aggregation function for axis=0. If any change
-# is made to this map, ensure GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE and
-# GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES are updated accordingly.
-SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = {
- "count": count,
- "mean": mean,
- "min": min_,
- "max": max_,
- "idxmax": functools.partial(
- _columns_coalescing_idxmax_idxmin_helper, func="idxmax"
- ),
- "idxmin": functools.partial(
- _columns_coalescing_idxmax_idxmin_helper, func="idxmin"
- ),
- "sum": sum_,
- "median": median,
- "skew": skew,
- "std": stddev,
- "var": variance,
- "all": builtin("booland_agg"),
- "any": builtin("boolor_agg"),
- np.max: max_,
- np.min: min_,
- np.sum: sum_,
- np.mean: mean,
- np.median: median,
- np.std: stddev,
- np.var: variance,
- "array_agg": array_agg,
- "quantile": column_quantile,
- "nunique": count_distinct,
-}
-GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = (
- "min",
- "max",
- "sum",
- "mean",
- "median",
- "std",
- np.max,
- np.min,
- np.sum,
- np.mean,
- np.median,
- np.std,
-)
-GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = (
- "any",
- "all",
- "count",
- "idxmax",
- "idxmin",
- "size",
- "nunique",
-)
+class _SnowparkPandasAggregation(NamedTuple):
+ """
+ A representation of a Snowpark pandas aggregation.
+
+ This structure gives us a common representation for an aggregation that may
+ have multiple aliases, like "sum" and np.sum.
+ """
+
+    # This field tells whether, if the types of all the inputs of the function
+    # are the same instance of SnowparkPandasType, the type of the result is
+    # that same instance of SnowparkPandasType. Note that this definition
+    # applies whether the aggregation is on axis=0 or axis=1. For example, the
+    # sum of a single timedelta column on axis 0 is another timedelta column.
+    # Equivalently, the sum of two timedelta columns along axis 1 is also
+    # another timedelta column. Therefore, preserves_snowpark_pandas_types for
+    # sum would be True.
+ preserves_snowpark_pandas_types: bool
+
+ # This callable takes a single Snowpark column as input and aggregates the
+ # column on axis=0. If None, Snowpark pandas does not support this
+ # aggregation on axis=0.
+ axis_0_aggregation: Optional[Callable] = None
+
+    # This callable takes one or more Snowpark columns as input and aggregates
+    # the columns on axis=1 with skipna=True, i.e. not including nulls in the
+    # aggregation. If None, Snowpark pandas does not support this aggregation
+    # on axis=1 with skipna=True.
+ axis_1_aggregation_skipna: Optional[Callable] = None
+
+    # This callable takes one or more Snowpark columns as input and aggregates
+    # the columns on axis=1 with skipna=False, i.e. including nulls in the
+    # aggregation. If None, Snowpark pandas does not support this aggregation
+    # on axis=1 with skipna=False.
+ axis_1_aggregation_keepna: Optional[Callable] = None
+
+
+class SnowflakeAggFunc(NamedTuple):
+ """
+ A Snowflake aggregation, including information about how the aggregation acts on SnowparkPandasType.
+ """
+
+ # The aggregation function in Snowpark.
+ # For aggregation on axis=0, this field should take a single Snowpark
+ # column and return the aggregated column.
+ # For aggregation on axis=1, this field should take an arbitrary number
+ # of Snowpark columns and return the aggregated column.
+ snowpark_aggregation: Callable
+
+    # This field tells whether, if the types of all the inputs of the function
+    # are the same instance of SnowparkPandasType, the type of the result is
+    # that same instance of SnowparkPandasType. Note that this definition
+    # applies whether the aggregation is on axis=0 or axis=1. For example, the
+    # sum of a single timedelta column on axis 0 is another timedelta column.
+    # Equivalently, the sum of two timedelta columns along axis 1 is also
+    # another timedelta column. Therefore, preserves_snowpark_pandas_types for
+    # sum would be True.
+ preserves_snowpark_pandas_types: bool
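# Illustrative, standalone sketch (not part of the diff) of what
# preserves_snowpark_pandas_types is meant to capture, using plain pandas:
# sum/mean of a timedelta column yield another timedelta, while count does not.
import pandas as pd

s = pd.Series(pd.to_timedelta([1, 2, 3], unit="s"))
assert isinstance(s.sum(), pd.Timedelta)        # type preserved
assert isinstance(s.mean(), pd.Timedelta)       # type preserved
assert not isinstance(s.count(), pd.Timedelta)  # count yields a plain integer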
class AggFuncWithLabel(NamedTuple):
@@ -413,35 +418,143 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable:
return sum(builtin("zeroifnull")(col) for col in cols)
-# Map between the pandas input aggregation function (str or numpy function) and
-# the corresponding aggregation function for axis=1 when skipna=True. The returned aggregation
-# function may either be a builtin aggregation function, or a function taking in *arg columns
-# that then calls the appropriate builtin aggregations.
-SNOWFLAKE_COLUMNS_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = {
- "count": _columns_count,
- "sum": _columns_coalescing_sum,
- np.sum: _columns_coalescing_sum,
- "min": _columns_coalescing_min,
- "max": _columns_coalescing_max,
- "idxmax": _columns_coalescing_idxmax_idxmin_helper,
- "idxmin": _columns_coalescing_idxmax_idxmin_helper,
- np.min: _columns_coalescing_min,
- np.max: _columns_coalescing_max,
-}
+def _create_pandas_to_snowpark_pandas_aggregation_map(
+ pandas_functions: Iterable[AggFuncTypeBase],
+ snowpark_pandas_aggregation: _SnowparkPandasAggregation,
+) -> MappingProxyType[AggFuncTypeBase, _SnowparkPandasAggregation]:
+ """
+ Create a map from the given pandas functions to the given _SnowparkPandasAggregation.
-# These functions are called instead if skipna=False
-SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = {
- "min": least,
- "max": greatest,
- "idxmax": _columns_coalescing_idxmax_idxmin_helper,
- "idxmin": _columns_coalescing_idxmax_idxmin_helper,
- # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark
- # sum_, since Snowpark sum_ gets the sum of all rows within a single column.
- "sum": lambda *cols: sum(cols),
- np.sum: lambda *cols: sum(cols),
- np.min: least,
- np.max: greatest,
-}
+    Args:
+        pandas_functions: The pandas functions that map to the given aggregation.
+        snowpark_pandas_aggregation: The aggregation to map to.
+
+ Returns:
+ The map.
+ """
+ return MappingProxyType({k: snowpark_pandas_aggregation for k in pandas_functions})
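# Standalone sketch (not part of the diff) of what the helper above returns: a
# read-only view that maps every alias (e.g. "sum" and np.sum) to one shared
# record, so the registry cannot be mutated accidentally at runtime. The names
# below are illustrative only.
from types import MappingProxyType
import numpy as np

shared_record = {"preserves_type": True}
alias_map = MappingProxyType({k: shared_record for k in ("sum", np.sum)})
assert alias_map["sum"] is alias_map[np.sum]
try:
    alias_map["sum"] = None  # type: ignore[index]
except TypeError:
    pass  # MappingProxyType rejects item assignment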
+
+
+# Map between the pandas input aggregation function (str or numpy function) and
+# _SnowparkPandasAggregation representing information about applying the
+# aggregation in Snowpark pandas.
+_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[
+ AggFuncTypeBase, _SnowparkPandasAggregation
+] = MappingProxyType(
+ {
+ "count": _SnowparkPandasAggregation(
+ axis_0_aggregation=count,
+ axis_1_aggregation_skipna=_columns_count,
+ preserves_snowpark_pandas_types=False,
+ ),
+ **_create_pandas_to_snowpark_pandas_aggregation_map(
+ ("mean", np.mean),
+ _SnowparkPandasAggregation(
+ axis_0_aggregation=mean,
+ preserves_snowpark_pandas_types=True,
+ ),
+ ),
+ **_create_pandas_to_snowpark_pandas_aggregation_map(
+ ("min", np.min),
+ _SnowparkPandasAggregation(
+ axis_0_aggregation=min_,
+ axis_1_aggregation_keepna=least,
+ axis_1_aggregation_skipna=_columns_coalescing_min,
+ preserves_snowpark_pandas_types=True,
+ ),
+ ),
+ **_create_pandas_to_snowpark_pandas_aggregation_map(
+ ("max", np.max),
+ _SnowparkPandasAggregation(
+ axis_0_aggregation=max_,
+ axis_1_aggregation_keepna=greatest,
+ axis_1_aggregation_skipna=_columns_coalescing_max,
+ preserves_snowpark_pandas_types=True,
+ ),
+ ),
+ **_create_pandas_to_snowpark_pandas_aggregation_map(
+ ("sum", np.sum),
+ _SnowparkPandasAggregation(
+ axis_0_aggregation=sum_,
+                # IMPORTANT: this uses the python builtin sum to invoke
+                # __add__ on each column rather than Snowpark sum_, since
+                # Snowpark sum_ computes the sum of all rows within a single column.
+ axis_1_aggregation_keepna=lambda *cols: sum(cols),
+ axis_1_aggregation_skipna=_columns_coalescing_sum,
+ preserves_snowpark_pandas_types=True,
+ ),
+ ),
+ **_create_pandas_to_snowpark_pandas_aggregation_map(
+ ("median", np.median),
+ _SnowparkPandasAggregation(
+ axis_0_aggregation=median,
+ preserves_snowpark_pandas_types=True,
+ ),
+ ),
+ "idxmax": _SnowparkPandasAggregation(
+ axis_0_aggregation=functools.partial(
+ _columns_coalescing_idxmax_idxmin_helper, func="idxmax"
+ ),
+ axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
+ axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
+ preserves_snowpark_pandas_types=False,
+ ),
+ "idxmin": _SnowparkPandasAggregation(
+ axis_0_aggregation=functools.partial(
+ _columns_coalescing_idxmax_idxmin_helper, func="idxmin"
+ ),
+ axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
+ axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
+ preserves_snowpark_pandas_types=False,
+ ),
+ "skew": _SnowparkPandasAggregation(
+ axis_0_aggregation=skew,
+ preserves_snowpark_pandas_types=True,
+ ),
+ "all": _SnowparkPandasAggregation(
+ # all() for a column with no non-null values is NULL in Snowflake, but True in pandas.
+ axis_0_aggregation=lambda c: coalesce(
+ builtin("booland_agg")(col(c)), pandas_lit(True)
+ ),
+ preserves_snowpark_pandas_types=False,
+ ),
+ "any": _SnowparkPandasAggregation(
+ # any() for a column with no non-null values is NULL in Snowflake, but False in pandas.
+ axis_0_aggregation=lambda c: coalesce(
+ builtin("boolor_agg")(col(c)), pandas_lit(False)
+ ),
+ preserves_snowpark_pandas_types=False,
+ ),
+ **_create_pandas_to_snowpark_pandas_aggregation_map(
+ ("std", np.std),
+ _SnowparkPandasAggregation(
+ axis_0_aggregation=stddev,
+ preserves_snowpark_pandas_types=True,
+ ),
+ ),
+ **_create_pandas_to_snowpark_pandas_aggregation_map(
+ ("var", np.var),
+ _SnowparkPandasAggregation(
+ axis_0_aggregation=variance,
+ # variance units are the square of the input column units, so
+ # variance does not preserve types.
+ preserves_snowpark_pandas_types=False,
+ ),
+ ),
+ "array_agg": _SnowparkPandasAggregation(
+ axis_0_aggregation=array_agg,
+ preserves_snowpark_pandas_types=False,
+ ),
+ "quantile": _SnowparkPandasAggregation(
+ axis_0_aggregation=column_quantile,
+ preserves_snowpark_pandas_types=True,
+ ),
+ "nunique": _SnowparkPandasAggregation(
+ axis_0_aggregation=count_distinct,
+ preserves_snowpark_pandas_types=False,
+ ),
+ }
+)
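# Plain pandas illustration (not part of the diff) of why "all"/"any" above
# coalesce a NULL aggregate to a literal: pandas returns True for all() and
# False for any() on a column with no non-null values, while the bare Snowflake
# BOOLAND_AGG/BOOLOR_AGG would yield NULL.
import pandas as pd

empty = pd.Series([None, None])
assert bool(empty.all()) is True   # vacuously true in pandas
assert bool(empty.any()) is False  # no truthy values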
class AggregateColumnOpParameters(NamedTuple):
@@ -462,7 +575,7 @@ class AggregateColumnOpParameters(NamedTuple):
agg_snowflake_quoted_identifier: str
# the snowflake aggregation function to apply on the column
- snowflake_agg_func: Callable
+ snowflake_agg_func: SnowflakeAggFunc
# the columns specifying the order of rows in the column. This is only
# relevant for aggregations that depend on row order, e.g. summing a string
@@ -471,88 +584,108 @@ class AggregateColumnOpParameters(NamedTuple):
def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool:
- return agg_func in SNOWFLAKE_BUILTIN_AGG_FUNC_MAP
+ return agg_func in _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION
def get_snowflake_agg_func(
- agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int = 0
-) -> Optional[Callable]:
+ agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1]
+) -> Optional[SnowflakeAggFunc]:
"""
Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function.
If no corresponding snowflake aggregation function can be found, return None.
"""
- if axis == 0:
- snowflake_agg_func = SNOWFLAKE_BUILTIN_AGG_FUNC_MAP.get(agg_func)
- if snowflake_agg_func == stddev or snowflake_agg_func == variance:
- # for aggregation function std and var, we only support ddof = 0 or ddof = 1.
- # when ddof is 1, std is mapped to stddev, var is mapped to variance
- # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop
- # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1
- ddof = agg_kwargs.get("ddof", 1)
- if ddof != 1 and ddof != 0:
- return None
- if ddof == 0:
- return stddev_pop if snowflake_agg_func == stddev else var_pop
- elif snowflake_agg_func == column_quantile:
- interpolation = agg_kwargs.get("interpolation", "linear")
- q = agg_kwargs.get("q", 0.5)
- if interpolation not in ("linear", "nearest"):
- return None
- if not is_scalar(q):
- # SNOW-1062878 Because list-like q would return multiple rows, calling quantile
- # through the aggregate frontend in this manner is unsupported.
- return None
- return lambda col: column_quantile(col, interpolation, q)
- elif agg_func in ("all", "any"):
- # If there are no rows in the input frame, the function will also return NULL, which should
- # instead by TRUE for "all" and FALSE for "any".
- # Need to wrap column name in IDENTIFIER, or else the agg function will treat the name
- # as a string literal.
- # The generated SQL expression for "all" is
- # IFNULL(BOOLAND_AGG(IDENTIFIER("column_name")), TRUE)
- # The expression for "any" is
- # IFNULL(BOOLOR_AGG(IDENTIFIER("column_name")), FALSE)
- default_value = bool(agg_func == "all")
- return lambda col: builtin("ifnull")(
- # mypy refuses to acknowledge snowflake_agg_func is non-NULL here
- snowflake_agg_func(builtin("identifier")(col)), # type: ignore[misc]
- pandas_lit(default_value),
+ if axis == 1:
+ return _generate_rowwise_aggregation_function(agg_func, agg_kwargs)
+
+ snowpark_pandas_aggregation = (
+ _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func)
+ )
+
+ if snowpark_pandas_aggregation is None:
+ # We don't have any implementation at all for this aggregation.
+ return None
+
+ snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation
+
+ if snowpark_aggregation is None:
+ # We don't have an implementation on axis=0 for this aggregation.
+ return None
+
+    # Rewrite some aggregations according to `agg_kwargs`.
+ if snowpark_aggregation == stddev or snowpark_aggregation == variance:
+ # for aggregation function std and var, we only support ddof = 0 or ddof = 1.
+ # when ddof is 1, std is mapped to stddev, var is mapped to variance
+ # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop
+ # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1
+ ddof = agg_kwargs.get("ddof", 1)
+ if ddof != 1 and ddof != 0:
+ return None
+ if ddof == 0:
+ snowpark_aggregation = (
+ stddev_pop if snowpark_aggregation == stddev else var_pop
)
- else:
- snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func)
+ elif snowpark_aggregation == column_quantile:
+ interpolation = agg_kwargs.get("interpolation", "linear")
+ q = agg_kwargs.get("q", 0.5)
+ if interpolation not in ("linear", "nearest"):
+ return None
+ if not is_scalar(q):
+ # SNOW-1062878 Because list-like q would return multiple rows, calling quantile
+ # through the aggregate frontend in this manner is unsupported.
+ return None
+
+ def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn:
+ return column_quantile(col, interpolation, q)
- return snowflake_agg_func
+ assert (
+ snowpark_aggregation is not None
+ ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation."
+ return SnowflakeAggFunc(
+ snowpark_aggregation=snowpark_aggregation,
+ preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
+ )
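# Standalone numpy sketch (not part of the diff) of the ddof rewrite performed
# above: ddof=1 corresponds to the sample statistic (Snowflake STDDEV/VARIANCE),
# while ddof=0 corresponds to the population statistic (STDDEV_POP/VAR_POP).
import numpy as np

x = np.array([1.0, 2.0, 4.0])
sample_var = x.var(ddof=1)       # what "var" maps to by default
population_var = x.var(ddof=0)   # what "var" maps to when ddof=0 is requested
assert sample_var > population_var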
-def generate_rowwise_aggregation_function(
+def _generate_rowwise_aggregation_function(
agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any]
-) -> Optional[Callable]:
+) -> Optional[SnowflakeAggFunc]:
"""
Get a callable taking *arg columns to apply for an aggregation.
Unlike get_snowflake_agg_func, this function may return a wrapped composition of
Snowflake builtin functions depending on the values of the specified kwargs.
"""
- snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func)
- if not agg_kwargs.get("skipna", True):
- snowflake_agg_func = SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP.get(
- agg_func, snowflake_agg_func
- )
+ snowpark_pandas_aggregation = (
+ _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func)
+ )
+ if snowpark_pandas_aggregation is None:
+ return None
+ snowpark_aggregation = (
+ snowpark_pandas_aggregation.axis_1_aggregation_skipna
+ if agg_kwargs.get("skipna", True)
+ else snowpark_pandas_aggregation.axis_1_aggregation_keepna
+ )
+ if snowpark_aggregation is None:
+ return None
min_count = agg_kwargs.get("min_count", 0)
if min_count > 0:
+ original_aggregation = snowpark_aggregation
+
# Create a case statement to check if the number of non-null values exceeds min_count
# when min_count > 0, if the number of not NULL values is < min_count, return NULL.
- def agg_func_wrapper(fn: Callable) -> Callable:
- return lambda *cols: when(
- _columns_count(*cols) < min_count, pandas_lit(None)
- ).otherwise(fn(*cols))
+ def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn:
+ return when(_columns_count(*cols) < min_count, pandas_lit(None)).otherwise(
+ original_aggregation(*cols)
+ )
- return snowflake_agg_func and agg_func_wrapper(snowflake_agg_func)
- return snowflake_agg_func
+ return SnowflakeAggFunc(
+ snowpark_aggregation,
+ preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
+ )
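# Standalone pandas illustration (not part of the diff) of the min_count
# wrapping applied above: when fewer than `min_count` values in a row are
# non-null, the row-wise aggregate becomes NULL/NaN instead of a partial sum.
import pandas as pd

df = pd.DataFrame({"a": [1.0, None], "b": [2.0, 3.0]})
out = df.sum(axis=1, min_count=2)
# row 0 has two non-nulls -> 3.0; row 1 has only one -> NaN
assert out.iloc[0] == 3.0 and pd.isna(out.iloc[1])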
-def is_supported_snowflake_agg_func(
- agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int
+def _is_supported_snowflake_agg_func(
+ agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1]
) -> bool:
"""
check if the aggregation function is supported with snowflake. Current supported
@@ -566,12 +699,14 @@ def is_supported_snowflake_agg_func(
is_valid: bool. Whether it is valid to implement with snowflake or not.
"""
if isinstance(agg_func, tuple) and len(agg_func) == 2:
+ # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
+ # take the second part of the named aggregation.
agg_func = agg_func[0]
return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None
-def are_all_agg_funcs_supported_by_snowflake(
- agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: int
+def _are_all_agg_funcs_supported_by_snowflake(
+ agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1]
) -> bool:
"""
Check if all aggregation functions in the given list are snowflake supported
@@ -582,14 +717,14 @@ def are_all_agg_funcs_supported_by_snowflake(
return False.
"""
return all(
- is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs
+ _is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs
)
def check_is_aggregation_supported_in_snowflake(
agg_func: AggFuncType,
agg_kwargs: dict[str, Any],
- axis: int,
+ axis: Literal[0, 1],
) -> bool:
"""
check if distributed implementation with snowflake is available for the aggregation
@@ -608,18 +743,18 @@ def check_is_aggregation_supported_in_snowflake(
if is_dict_like(agg_func):
return all(
(
- are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis)
+ _are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis)
if is_list_like(value) and not is_named_tuple(value)
- else is_supported_snowflake_agg_func(value, agg_kwargs, axis)
+ else _is_supported_snowflake_agg_func(value, agg_kwargs, axis)
)
for value in agg_func.values()
)
elif is_list_like(agg_func):
- return are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis)
- return is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis)
+ return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis)
+ return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis)
-def is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool:
+def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool:
"""
    Check whether the given Snowflake aggregation function needs to be applied on a numeric column.
"""
@@ -697,7 +832,7 @@ def drop_non_numeric_data_columns(
)
-def generate_aggregation_column(
+def _generate_aggregation_column(
agg_column_op_params: AggregateColumnOpParameters,
agg_kwargs: dict[str, Any],
is_groupby_agg: bool,
@@ -721,8 +856,14 @@ def generate_aggregation_column(
SnowparkColumn after the aggregation function. The column is also aliased back to the original name
"""
snowpark_column = agg_column_op_params.snowflake_quoted_identifier
- snowflake_agg_func = agg_column_op_params.snowflake_agg_func
- if is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance(
+ snowflake_agg_func = agg_column_op_params.snowflake_agg_func.snowpark_aggregation
+
+ if snowflake_agg_func in (variance, var_pop) and isinstance(
+ agg_column_op_params.data_type, TimedeltaType
+ ):
+ raise TypeError("timedelta64 type does not support var operations")
+
+ if _is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance(
agg_column_op_params.data_type, BooleanType
):
# if the column is a boolean column and the aggregation function requires numeric values,
@@ -753,7 +894,7 @@ def generate_aggregation_column(
# note that we always assume keepna for array_agg. TODO(SNOW-1040398):
# make keepna treatment consistent across array_agg and other
# aggregation methods.
- agg_snowpark_column = array_agg_keepna(
+ agg_snowpark_column = _array_agg_keepna(
snowpark_column, ordering_columns=agg_column_op_params.ordering_columns
)
elif (
@@ -825,6 +966,19 @@ def generate_aggregation_column(
), f"No case expression is constructed with skipna({skipna}), min_count({min_count})"
agg_snowpark_column = case_expr.otherwise(agg_snowpark_column)
+ if (
+ isinstance(agg_column_op_params.data_type, TimedeltaType)
+ and agg_column_op_params.snowflake_agg_func.preserves_snowpark_pandas_types
+ ):
+ # timedelta aggregations that produce timedelta results might produce
+ # a decimal type in snowflake, e.g.
+ # pd.Series([pd.Timestamp(1), pd.Timestamp(2)]).mean() produces 1.5 in
+ # Snowflake. We truncate the decimal part of the result, as pandas
+ # does.
+ agg_snowpark_column = cast(
+ trunc(agg_snowpark_column), agg_column_op_params.data_type.snowpark_type
+ )
+
# rename the column to agg_column_quoted_identifier
agg_snowpark_column = agg_snowpark_column.as_(
agg_column_op_params.agg_snowflake_quoted_identifier
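# Standalone pandas illustration (not part of the diff) of why the timedelta
# aggregate above is truncated and cast back: nanosecond timedeltas have integer
# resolution, so a mean that is fractional in Snowflake's decimal arithmetic
# (e.g. 1.5 ns) must come back as a whole number of nanoseconds to round-trip
# into a pandas Timedelta.
import pandas as pd

s = pd.Series(pd.to_timedelta([1, 2], unit="ns"))
result = s.mean()
assert isinstance(result, pd.Timedelta)
assert isinstance(result.value, int)  # whole number of nanoseconds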
@@ -857,7 +1011,7 @@ def aggregate_with_ordered_dataframe(
is_groupby_agg = groupby_columns is not None
agg_list: list[SnowparkColumn] = [
- generate_aggregation_column(
+ _generate_aggregation_column(
agg_column_op_params=agg_col_op,
agg_kwargs=agg_kwargs,
is_groupby_agg=is_groupby_agg,
@@ -973,7 +1127,7 @@ def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str:
)
-def generate_pandas_labels_for_agg_result_columns(
+def _generate_pandas_labels_for_agg_result_columns(
pandas_label: Hashable,
num_levels: int,
agg_func_list: list[AggFuncInfo],
@@ -1102,7 +1256,7 @@ def generate_column_agg_info(
)
# generate the pandas label and quoted identifier for the result aggregation columns, one
# for each aggregation function to apply.
- agg_col_labels = generate_pandas_labels_for_agg_result_columns(
+ agg_col_labels = _generate_pandas_labels_for_agg_result_columns(
pandas_label_to_identifier.pandas_label,
num_levels,
agg_func_list, # type: ignore[arg-type]
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py
index b58ba4f50ea..f87cdcd2e47 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py
@@ -81,7 +81,7 @@ class GroupbyApplySortMethod(Enum):
def check_return_variant_and_get_return_type(func: Callable) -> tuple[bool, DataType]:
"""Check whether the function returns a variant in Snowflake, and get its return type."""
- return_type, _ = get_types_from_type_hints(func, TempObjectType.FUNCTION)
+ return_type = deduce_return_type_from_function(func)
if return_type is None or isinstance(
return_type, (VariantType, PandasSeriesType, PandasDataFrameType)
):
@@ -390,6 +390,7 @@ def create_udtf_for_groupby_apply(
series_groupby: bool,
by_types: list[DataType],
existing_identifiers: list[str],
+ force_list_like_to_series: bool = False,
) -> UserDefinedTableFunction:
"""
Create a UDTF from the Python function for groupby.apply.
@@ -480,6 +481,7 @@ def create_udtf_for_groupby_apply(
series_groupby: Whether we are performing a SeriesGroupBy.apply() instead of DataFrameGroupBy.apply()
by_types: The snowflake types of the by columns.
existing_identifiers: List of existing column identifiers; these are omitted when creating new column identifiers.
+        force_list_like_to_series: Force the function result to a Series if it is list-like.
Returns
-------
@@ -553,6 +555,17 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
# https://github.com/snowflakedb/snowpandas/pull/823/files#r1507286892
input_object = input_object.infer_objects()
func_result = func(input_object, *args, **kwargs)
+ if (
+ force_list_like_to_series
+ and not isinstance(func_result, native_pd.Series)
+ and native_pd.api.types.is_list_like(func_result)
+ ):
+ if len(func_result) == 1:
+ func_result = func_result[0]
+ else:
+ func_result = native_pd.Series(func_result)
+ if len(func_result) == len(df.index):
+ func_result.index = df.index
if isinstance(func_result, native_pd.Series):
if series_groupby:
func_result_as_frame = func_result.to_frame()
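# Standalone sketch (not part of the diff) of the force_list_like_to_series
# normalization above, assuming a plain pandas result object: a length-1
# list-like collapses to a scalar, and a longer list-like becomes a Series
# aligned to the group's index when the lengths match.
import pandas as native_pd

def normalize(func_result, index):
    if not isinstance(
        func_result, native_pd.Series
    ) and native_pd.api.types.is_list_like(func_result):
        if len(func_result) == 1:
            return func_result[0]
        func_result = native_pd.Series(func_result)
        if len(func_result) == len(index):
            func_result.index = index
    return func_result

assert normalize([7], index=native_pd.RangeIndex(3)) == 7
assert normalize([1, 2, 3], index=native_pd.Index(["a", "b", "c"])).index.tolist() == ["a", "b", "c"]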
@@ -754,7 +767,7 @@ def __init__(self) -> None:
def convert_numpy_int_result_to_int(value: Any) -> Any:
"""
- If the result is a numpy int, convert it to a python int.
+    If the result is a numpy int (or bool), convert it to a python int (or bool).
Use this function to make UDF results JSON-serializable. numpy ints are not
JSON-serializable, but python ints are. Note that this function cannot make
@@ -772,9 +785,14 @@ def convert_numpy_int_result_to_int(value: Any) -> Any:
Returns
-------
- int(value) if the value is a numpy int, otherwise the value.
+ int(value) if the value is a numpy int,
+ bool(value) if the value is a numpy bool, otherwise the value.
"""
- return int(value) if np.issubdtype(type(value), np.integer) else value
+ return (
+ int(value)
+ if np.issubdtype(type(value), np.integer)
+ else (bool(value) if np.issubdtype(type(value), np.bool_) else value)
+ )
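# Standalone sketch (not part of the diff) of the motivation for the conversion
# above: numpy ints/bools are not JSON-serializable, while the equivalent Python
# scalars are.
import json
import numpy as np

try:
    json.dumps(np.int64(1))
except TypeError:
    pass  # numpy scalar rejected by the json module
json.dumps(int(np.int64(1)))      # fine after conversion
json.dumps(bool(np.bool_(True)))  # fine after conversion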
def deduce_return_type_from_function(
@@ -887,7 +905,7 @@ def get_metadata_from_groupby_apply_pivot_result_column_names(
input:
get_metadata_from_groupby_apply_pivot_result_column_names([
- # this representa a data column named ('a', 'group_key') at position 0
+ # this represents a data column named ('a', 'group_key') at position 0
'"\'{""0"": ""a"", ""1"": ""group_key"", ""data_pos"": 0, ""names"": [""c1"", ""c2""]}\'"',
# this represents a data column named ('b', 'int_col') at position 1
'"\'{""0"": ""b"", ""1"": ""int_col"", ""data_pos"": 1, ""names"": [""c1"", ""c2""]}\'"',
@@ -1110,7 +1128,9 @@ def groupby_apply_pivot_result_to_final_ordered_dataframe(
# in GROUP_KEY_APPEARANCE_ORDER) and assign the
# label i to all rows that came from func(group_i).
[
- original_row_position_snowflake_quoted_identifier
+ col(original_row_position_snowflake_quoted_identifier).as_(
+ new_index_identifier
+ )
if sort_method is GroupbyApplySortMethod.ORIGINAL_ROW_ORDER
else (
dense_rank().over(
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py
index 1aa81b36e64..475fbfcefa7 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py
@@ -185,19 +185,6 @@ def compute_power_between_snowpark_columns(
return result
-def is_binary_op_supported(op: str) -> bool:
- """
- check whether binary operation is mappable to Snowflake
- Args
- op: op as string
-
- Returns:
- True if binary operation can be mapped to Snowflake/Snowpark, else False
- """
-
- return op in SUPPORTED_BINARY_OPERATIONS
-
-
def _compute_subtraction_between_snowpark_timestamp_columns(
first_operand: SnowparkColumn,
first_datatype: DataType,
@@ -312,314 +299,527 @@ def _op_is_between_timedelta_and_numeric(
)
-def compute_binary_op_between_snowpark_columns(
- op: str,
- first_operand: SnowparkColumn,
- first_datatype: DataTypeGetter,
- second_operand: SnowparkColumn,
- second_datatype: DataTypeGetter,
-) -> SnowparkPandasColumn:
- """
- Compute pandas binary operation for two SnowparkColumns
- Args:
- op: pandas operation
- first_operand: SnowparkColumn for lhs
- first_datatype: Callable for Snowpark Datatype for lhs
- second_operand: SnowparkColumn for rhs
- second_datatype: Callable for Snowpark DateType for rhs
- it is not needed.
+class BinaryOp:
+ def __init__(
+ self,
+ op: str,
+ first_operand: SnowparkColumn,
+ first_datatype: DataTypeGetter,
+ second_operand: SnowparkColumn,
+ second_datatype: DataTypeGetter,
+ ) -> None:
+ """
+        Construct a BinaryOp object to compute a pandas binary operation for two SnowparkColumns.
+        Args:
+            op: pandas operation
+            first_operand: SnowparkColumn for lhs
+            first_datatype: Callable for the Snowpark DataType for lhs
+            second_operand: SnowparkColumn for rhs
+            second_datatype: Callable for the Snowpark DataType for rhs; the type is computed lazily, only if it is needed.
+ """
+ self.op = op
+ self.first_operand = first_operand
+ self.first_datatype = first_datatype
+ self.second_operand = second_operand
+ self.second_datatype = second_datatype
+ self.result_column = None
+ self.result_snowpark_pandas_type = None
+
+ @staticmethod
+ def is_binary_op_supported(op: str) -> bool:
+ """
+ check whether binary operation is mappable to Snowflake
+ Args
+ op: op as string
+
+ Returns:
+ True if binary operation can be mapped to Snowflake/Snowpark, else False
+ """
+
+ return op in SUPPORTED_BINARY_OPERATIONS
+
+ @staticmethod
+ def create(
+ op: str,
+ first_operand: SnowparkColumn,
+ first_datatype: DataTypeGetter,
+ second_operand: SnowparkColumn,
+ second_datatype: DataTypeGetter,
+ ) -> "BinaryOp":
+ """
+        Create a BinaryOp object to compute a pandas binary operation for two SnowparkColumns.
+        Args:
+            op: pandas operation
+            first_operand: SnowparkColumn for lhs
+            first_datatype: Callable for the Snowpark DataType for lhs
+            second_operand: SnowparkColumn for rhs
+            second_datatype: Callable for the Snowpark DataType for rhs; the type is computed lazily, only if it is needed.
+ """
+
+ def snake_to_camel(snake_str: str) -> str:
+ """Converts a snake case string to camel case."""
+ components = snake_str.split("_")
+ return "".join(x.title() for x in components)
+
+ if op in _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP:
+ # Normalize right-sided binary operations to the equivalent left-sided
+ # operations with swapped operands. For example, rsub(col(a), col(b))
+ # becomes sub(col(b), col(a))
+ op, first_operand, first_datatype, second_operand, second_datatype = (
+ _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP[op],
+ second_operand,
+ second_datatype,
+ first_operand,
+ first_datatype,
+ )
- Returns:
- SnowparkPandasColumn for translated pandas operation
- """
- if op in _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP:
- # Normalize right-sided binary operations to the equivalent left-sided
- # operations with swapped operands. For example, rsub(col(a), col(b))
- # becomes sub(col(b), col(a))
- op, first_operand, first_datatype, second_operand, second_datatype = (
- _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP[op],
- second_operand,
- second_datatype,
- first_operand,
- first_datatype,
+ class_name = f"{snake_to_camel(op)}Op"
+ op_class = None
+ for subclass in BinaryOp.__subclasses__():
+ if subclass.__name__ == class_name:
+ op_class = subclass
+ if op_class is None:
+ op_class = BinaryOp
+ return op_class(
+ op, first_operand, first_datatype, second_operand, second_datatype
)
- binary_op_result_column = None
- snowpark_pandas_type = None
+ @staticmethod
+ def create_with_fill_value(
+ op: str,
+ lhs: SnowparkColumn,
+ lhs_datatype: DataTypeGetter,
+ rhs: SnowparkColumn,
+ rhs_datatype: DataTypeGetter,
+ fill_value: Scalar,
+ ) -> "BinaryOp":
+ """
+        Create a BinaryOp object to compute a pandas binary operation for two SnowparkColumns, with a fill
+        value for missing values.
+
+        Helper method for performing binary operations.
+        1. Fills NaN/None values in the lhs and rhs with the given fill_value.
+        2. Computes the binary operation expression for lhs <op> rhs.
+
+ fill_value replaces NaN/None values when only either lhs or rhs is NaN/None, not both lhs and rhs.
+ For instance, with fill_value = 100,
+ 1. Given lhs = None and rhs = 10, lhs is replaced with fill_value.
+ result = lhs + rhs => None + 10 => 100 (replaced) + 10 = 110
+ 2. Given lhs = 3 and rhs = None, rhs is replaced with fill_value.
+ result = lhs + rhs => 3 + None => 3 + 100 (replaced) = 103
+ 3. Given lhs = None and rhs = None, neither lhs nor rhs is replaced since they both are None.
+ result = lhs + rhs => None + None => None.
+
+ Args:
+ op: pandas operation to perform between lhs and rhs
+ lhs: the lhs SnowparkColumn
+ lhs_datatype: Callable for Snowpark Datatype for lhs
+ rhs: the rhs SnowparkColumn
+ rhs_datatype: Callable for Snowpark Datatype for rhs
+ fill_value: Fill existing missing (NaN) values, and any new element needed for
+ successful DataFrame alignment, with this value before computation.
+
+ Returns:
+            The BinaryOp for the translated pandas operation.
+ """
+ lhs_cond, rhs_cond = lhs, rhs
+ if fill_value is not None:
+ fill_value_lit = pandas_lit(fill_value)
+ lhs_cond = iff(lhs.is_null() & ~rhs.is_null(), fill_value_lit, lhs)
+ rhs_cond = iff(rhs.is_null() & ~lhs.is_null(), fill_value_lit, rhs)
+
+ return BinaryOp.create(op, lhs_cond, lhs_datatype, rhs_cond, rhs_datatype)
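# Standalone pandas illustration (not part of the diff) of the fill_value
# semantics documented above: only the side that is missing gets filled, and a
# pair where both sides are missing stays NaN.
import pandas as pd

lhs = pd.Series([None, 3, None], dtype="float64")
rhs = pd.Series([10, None, None], dtype="float64")
result = lhs.add(rhs, fill_value=100)
assert result.iloc[0] == 110.0   # lhs filled with 100
assert result.iloc[1] == 103.0   # rhs filled with 100
assert pd.isna(result.iloc[2])   # both missing -> stays NaN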
+
+ @staticmethod
+ def create_with_rhs_scalar(
+ op: str,
+ first_operand: SnowparkColumn,
+ datatype: DataTypeGetter,
+ second_operand: Scalar,
+ ) -> "BinaryOp":
+ """
+        Create a BinaryOp object to compute the binary operation between a Snowpark column and a scalar.
+        Args:
+            op: the name of the binary operation
+            first_operand: The SnowparkColumn for lhs
+            datatype: Callable for the Snowpark data type of the column
+            second_operand: Scalar value
+
+        Returns:
+            The BinaryOp for the translated pandas operation.
+ """
+
+ def second_datatype() -> DataType:
+ return infer_object_type(second_operand)
+
+ return BinaryOp.create(
+ op, first_operand, datatype, pandas_lit(second_operand), second_datatype
+ )
- # some operators and the data types have to be handled specially to align with pandas
- # However, it is difficult to fail early if the arithmetic operator is not compatible
- # with the data type, so we just let the server raise exception (e.g. a string minus a string).
- if (
- op == "add"
- and isinstance(second_datatype(), TimedeltaType)
- and isinstance(first_datatype(), TimestampType)
- ):
- binary_op_result_column = dateadd("ns", second_operand, first_operand)
- elif (
- op == "add"
- and isinstance(first_datatype(), TimedeltaType)
- and isinstance(second_datatype(), TimestampType)
- ):
- binary_op_result_column = dateadd("ns", first_operand, second_operand)
- elif op in (
- "add",
- "sub",
- "eq",
- "ne",
- "gt",
- "ge",
- "lt",
- "le",
- "floordiv",
- "truediv",
- ) and (
- (
- isinstance(first_datatype(), TimedeltaType)
- and isinstance(second_datatype(), NullType)
+ @staticmethod
+ def create_with_lhs_scalar(
+ op: str,
+ first_operand: Scalar,
+ second_operand: SnowparkColumn,
+ datatype: DataTypeGetter,
+ ) -> "BinaryOp":
+ """
+        Create a BinaryOp object to compute the binary operation between a scalar and a Snowpark column.
+        Args:
+            op: the name of the binary operation
+            first_operand: Scalar value
+            second_operand: The SnowparkColumn for rhs
+            datatype: Callable for the Snowpark data type of the column; the type is computed lazily, only if it is needed.
+
+        Returns:
+            The BinaryOp for the translated pandas operation.
+ """
+
+ def first_datatype() -> DataType:
+ return infer_object_type(first_operand)
+
+ return BinaryOp.create(
+ op, pandas_lit(first_operand), first_datatype, second_operand, datatype
)
- or (
- isinstance(second_datatype(), TimedeltaType)
- and isinstance(first_datatype(), NullType)
+
+ def _custom_compute(self) -> None:
+ """Implement custom compute method if needed."""
+ pass
+
+ def _get_result(self) -> SnowparkPandasColumn:
+ return SnowparkPandasColumn(
+ snowpark_column=self.result_column,
+ snowpark_pandas_type=self.result_snowpark_pandas_type,
)
- ):
- return SnowparkPandasColumn(pandas_lit(None), TimedeltaType())
- elif (
- op == "sub"
- and isinstance(second_datatype(), TimedeltaType)
- and isinstance(first_datatype(), TimestampType)
- ):
- binary_op_result_column = dateadd("ns", -1 * second_operand, first_operand)
- elif (
- op == "sub"
- and isinstance(first_datatype(), TimedeltaType)
- and isinstance(second_datatype(), TimestampType)
- ):
+
+ def _check_timedelta_with_none(self) -> None:
+ if self.op in (
+ "add",
+ "sub",
+ "eq",
+ "ne",
+ "gt",
+ "ge",
+ "lt",
+ "le",
+ "floordiv",
+ "truediv",
+ ) and (
+ (
+ isinstance(self.first_datatype(), TimedeltaType)
+ and isinstance(self.second_datatype(), NullType)
+ )
+ or (
+ isinstance(self.second_datatype(), TimedeltaType)
+ and isinstance(self.first_datatype(), NullType)
+ )
+ ):
+ self.result_column = pandas_lit(None)
+ self.result_snowpark_pandas_type = TimedeltaType()
+
+ def _check_error(self) -> None:
# Timedelta - Timestamp doesn't make sense. Raise the same error
# message as pandas.
- raise TypeError("bad operand type for unary -: 'DatetimeArray'")
- elif op == "mod" and _op_is_between_two_timedeltas_or_timedelta_and_null(
- first_datatype(), second_datatype()
- ):
- binary_op_result_column = compute_modulo_between_snowpark_columns(
- first_operand, first_datatype(), second_operand, second_datatype()
- )
- snowpark_pandas_type = TimedeltaType()
- elif op == "pow" and _op_is_between_two_timedeltas_or_timedelta_and_null(
- first_datatype(), second_datatype()
- ):
- raise TypeError("unsupported operand type for **: Timedelta")
- elif op == "__or__" and _op_is_between_two_timedeltas_or_timedelta_and_null(
- first_datatype(), second_datatype()
- ):
- raise TypeError("unsupported operand type for |: Timedelta")
- elif op == "__and__" and _op_is_between_two_timedeltas_or_timedelta_and_null(
- first_datatype(), second_datatype()
- ):
- raise TypeError("unsupported operand type for &: Timedelta")
- elif (
- op in ("add", "sub")
- and isinstance(first_datatype(), TimedeltaType)
- and isinstance(second_datatype(), TimedeltaType)
- ):
- snowpark_pandas_type = TimedeltaType()
- elif op == "mul" and _op_is_between_two_timedeltas_or_timedelta_and_null(
- first_datatype(), second_datatype()
- ):
- raise np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined]
- np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]"))
- )
- elif op in (
- "eq",
- "ne",
- "gt",
- "ge",
- "lt",
- "le",
- ) and _op_is_between_two_timedeltas_or_timedelta_and_null(
- first_datatype(), second_datatype()
- ):
- # These operations, when done between timedeltas, work without any
- # extra handling in `snowpark_pandas_type` or `binary_op_result_column`.
- pass
- elif op == "mul" and (
- _op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
- ):
- binary_op_result_column = cast(
- floor(first_operand * second_operand), LongType()
- )
- snowpark_pandas_type = TimedeltaType()
- # For `eq` and `ne`, note that Snowflake will consider 1 equal to
- # Timedelta(1) because those two have the same representation in Snowflake,
- # so we have to compare types in the client.
- elif op == "eq" and (
- _op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
- ):
- binary_op_result_column = pandas_lit(False)
- elif op == "ne" and _op_is_between_timedelta_and_numeric(
- first_datatype, second_datatype
- ):
- binary_op_result_column = pandas_lit(True)
- elif (
- op in ("truediv", "floordiv")
- and isinstance(first_datatype(), TimedeltaType)
- and _is_numeric_non_timedelta_type(second_datatype())
- ):
- binary_op_result_column = cast(
- floor(first_operand / second_operand), LongType()
- )
- snowpark_pandas_type = TimedeltaType()
- elif (
- op == "mod"
- and isinstance(first_datatype(), TimedeltaType)
- and _is_numeric_non_timedelta_type(second_datatype())
- ):
- binary_op_result_column = ceil(
- compute_modulo_between_snowpark_columns(
- first_operand, first_datatype(), second_operand, second_datatype()
+ if (
+ self.op == "sub"
+ and isinstance(self.first_datatype(), TimedeltaType)
+ and isinstance(self.second_datatype(), TimestampType)
+ ):
+ raise TypeError("bad operand type for unary -: 'DatetimeArray'")
+
+ # Raise error for two timedelta or timedelta and null
+ two_timedeltas_or_timedelta_and_null_error = {
+ "pow": TypeError("unsupported operand type for **: Timedelta"),
+ "__or__": TypeError("unsupported operand type for |: Timedelta"),
+ "__and__": TypeError("unsupported operand type for &: Timedelta"),
+ "mul": np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined]
+ np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]"))
+ ),
+ }
+ if (
+ self.op in two_timedeltas_or_timedelta_and_null_error
+ and _op_is_between_two_timedeltas_or_timedelta_and_null(
+ self.first_datatype(), self.second_datatype()
)
- )
- snowpark_pandas_type = TimedeltaType()
- elif op in ("add", "sub") and (
- (
- isinstance(first_datatype(), TimedeltaType)
- and _is_numeric_non_timedelta_type(second_datatype())
- )
- or (
- _is_numeric_non_timedelta_type(first_datatype())
- and isinstance(second_datatype(), TimedeltaType)
- )
- ):
- raise TypeError(
- "Snowpark pandas does not support addition or subtraction between timedelta values and numeric values."
- )
- elif op in ("truediv", "floordiv", "mod") and (
- _is_numeric_non_timedelta_type(first_datatype())
- and isinstance(second_datatype(), TimedeltaType)
- ):
- raise TypeError(
- "Snowpark pandas does not support dividing numeric values by timedelta values with div (/), mod (%), or floordiv (//)."
- )
- elif op in (
- "add",
- "sub",
- "truediv",
- "floordiv",
- "mod",
- "gt",
- "ge",
- "lt",
- "le",
- "ne",
- "eq",
- ) and (
- (
- isinstance(first_datatype(), TimedeltaType)
- and isinstance(second_datatype(), StringType)
- )
- or (
- isinstance(second_datatype(), TimedeltaType)
- and isinstance(first_datatype(), StringType)
- )
- ):
+ ):
+ raise two_timedeltas_or_timedelta_and_null_error[self.op]
+
+ if self.op in ("add", "sub") and (
+ (
+ isinstance(self.first_datatype(), TimedeltaType)
+ and _is_numeric_non_timedelta_type(self.second_datatype())
+ )
+ or (
+ _is_numeric_non_timedelta_type(self.first_datatype())
+ and isinstance(self.second_datatype(), TimedeltaType)
+ )
+ ):
+ raise TypeError(
+ "Snowpark pandas does not support addition or subtraction between timedelta values and numeric values."
+ )
+
+ if self.op in ("truediv", "floordiv", "mod") and (
+ _is_numeric_non_timedelta_type(self.first_datatype())
+ and isinstance(self.second_datatype(), TimedeltaType)
+ ):
+ raise TypeError(
+ "Snowpark pandas does not support dividing numeric values by timedelta values with div (/), mod (%), "
+ "or floordiv (//)."
+ )
+
# TODO(SNOW-1646604): Support these cases.
- ErrorMessage.not_implemented(
- f"Snowpark pandas does not yet support the operation {op} between timedelta and string"
- )
- elif op in ("gt", "ge", "lt", "le", "pow", "__or__", "__and__") and (
- _op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
- ):
- raise TypeError(
- f"Snowpark pandas does not support binary operation {op} between timedelta and a non-timedelta type."
- )
- elif op == "floordiv":
- binary_op_result_column = floor(first_operand / second_operand)
- elif op == "mod":
- binary_op_result_column = compute_modulo_between_snowpark_columns(
- first_operand, first_datatype(), second_operand, second_datatype()
- )
- elif op == "pow":
- binary_op_result_column = compute_power_between_snowpark_columns(
- first_operand, second_operand
- )
- elif op == "__or__":
- binary_op_result_column = first_operand | second_operand
- elif op == "__and__":
- binary_op_result_column = first_operand & second_operand
- elif (
- op == "add"
- and isinstance(second_datatype(), StringType)
- and isinstance(first_datatype(), StringType)
- ):
- # string/string case (only for add)
- binary_op_result_column = concat(first_operand, second_operand)
- elif op == "mul" and (
- (
- isinstance(second_datatype(), _IntegralType)
- and isinstance(first_datatype(), StringType)
- )
- or (
- isinstance(second_datatype(), StringType)
- and isinstance(first_datatype(), _IntegralType)
+ if self.op in (
+ "add",
+ "sub",
+ "truediv",
+ "floordiv",
+ "mod",
+ "gt",
+ "ge",
+ "lt",
+ "le",
+ "ne",
+ "eq",
+ ) and (
+ (
+ isinstance(self.first_datatype(), TimedeltaType)
+ and isinstance(self.second_datatype(), StringType)
+ )
+ or (
+ isinstance(self.second_datatype(), TimedeltaType)
+ and isinstance(self.first_datatype(), StringType)
+ )
+ ):
+ ErrorMessage.not_implemented(
+ f"Snowpark pandas does not yet support the operation {self.op} between timedelta and string"
+ )
+
+ if self.op in ("gt", "ge", "lt", "le", "pow", "__or__", "__and__") and (
+ _op_is_between_timedelta_and_numeric(
+ self.first_datatype, self.second_datatype
+ )
+ ):
+ raise TypeError(
+ f"Snowpark pandas does not support binary operation {self.op} between timedelta and a non-timedelta "
+ f"type."
+ )
+
+ def compute(self) -> SnowparkPandasColumn:
+ self._check_error()
+
+ self._check_timedelta_with_none()
+
+ if self.result_column is not None:
+ return self._get_result()
+
+        # Generally, some operators and data types have to be handled specially to align with pandas.
+        # However, it is difficult to fail early if the arithmetic operator is not compatible
+        # with the data type, so we just let the server raise an exception (e.g. a string minus a string).
+
+ self._custom_compute()
+ if self.result_column is None:
+            # If no special result_column was produced, the operator and the
+            # data type of the column don't need special handling. Then we get the overloaded
+            # operator from the Snowpark Column class, e.g., __add__, to perform the binary operation.
+ self.result_column = getattr(self.first_operand, f"__{self.op}__")(
+ self.second_operand
+ )
+
+ return self._get_result()
+
+
+class AddOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ if isinstance(self.second_datatype(), TimedeltaType) and isinstance(
+ self.first_datatype(), TimestampType
+ ):
+ self.result_column = dateadd("ns", self.second_operand, self.first_operand)
+ elif isinstance(self.first_datatype(), TimedeltaType) and isinstance(
+ self.second_datatype(), TimestampType
+ ):
+ self.result_column = dateadd("ns", self.first_operand, self.second_operand)
+ elif isinstance(self.first_datatype(), TimedeltaType) and isinstance(
+ self.second_datatype(), TimedeltaType
+ ):
+ self.result_snowpark_pandas_type = TimedeltaType()
+ elif isinstance(self.second_datatype(), StringType) and isinstance(
+ self.first_datatype(), StringType
+ ):
+ # string/string case (only for add)
+ self.result_column = concat(self.first_operand, self.second_operand)
+
+
+class SubOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ if isinstance(self.second_datatype(), TimedeltaType) and isinstance(
+ self.first_datatype(), TimestampType
+ ):
+ self.result_column = dateadd(
+ "ns", -1 * self.second_operand, self.first_operand
+ )
+ elif isinstance(self.first_datatype(), TimedeltaType) and isinstance(
+ self.second_datatype(), TimedeltaType
+ ):
+ self.result_snowpark_pandas_type = TimedeltaType()
+ elif isinstance(self.first_datatype(), TimestampType) and isinstance(
+ self.second_datatype(), NullType
+ ):
+ # Timestamp - NULL or NULL - Timestamp raises SQL compilation error,
+ # but it's valid in pandas and returns NULL.
+ self.result_column = pandas_lit(None)
+ elif isinstance(self.first_datatype(), NullType) and isinstance(
+ self.second_datatype(), TimestampType
+ ):
+ # Timestamp - NULL or NULL - Timestamp raises SQL compilation error,
+ # but it's valid in pandas and returns NULL.
+ self.result_column = pandas_lit(None)
+ elif isinstance(self.first_datatype(), TimestampType) and isinstance(
+ self.second_datatype(), TimestampType
+ ):
+ (
+ self.result_column,
+ self.result_snowpark_pandas_type,
+ ) = _compute_subtraction_between_snowpark_timestamp_columns(
+ first_operand=self.first_operand,
+ first_datatype=self.first_datatype(),
+ second_operand=self.second_operand,
+ second_datatype=self.second_datatype(),
+ )
+
+
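# Plain pandas counterpart (not part of the diff) of the Timestamp - Timestamp
# branch handled below: subtracting two datetime columns produces a timedelta
# column, which is why the result type switches rather than staying a timestamp.
import pandas as pd

a = pd.Series(pd.to_datetime(["2024-01-02"]))
b = pd.Series(pd.to_datetime(["2024-01-01"]))
assert str((a - b).dtype) == "timedelta64[ns]"
assert (a - b).iloc[0] == pd.Timedelta(days=1)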
+class ModOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ self.result_column = compute_modulo_between_snowpark_columns(
+ self.first_operand,
+ self.first_datatype(),
+ self.second_operand,
+ self.second_datatype(),
)
- ):
- # string/integer case (only for mul/rmul).
- # swap first_operand with second_operand because
- # REPEAT(, ) expects to be string
- if isinstance(first_datatype(), _IntegralType):
- first_operand, second_operand = second_operand, first_operand
-
- binary_op_result_column = iff(
- second_operand > pandas_lit(0),
- repeat(first_operand, second_operand),
- # Snowflake's repeat doesn't support negative number,
- # but pandas will return an empty string
- pandas_lit(""),
+ if _op_is_between_two_timedeltas_or_timedelta_and_null(
+ self.first_datatype(), self.second_datatype()
+ ):
+ self.result_snowpark_pandas_type = TimedeltaType()
+ elif isinstance(
+ self.first_datatype(), TimedeltaType
+ ) and _is_numeric_non_timedelta_type(self.second_datatype()):
+ self.result_column = ceil(self.result_column)
+ self.result_snowpark_pandas_type = TimedeltaType()
+
+
+class MulOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ if _op_is_between_timedelta_and_numeric(
+ self.first_datatype, self.second_datatype
+ ):
+ self.result_column = cast(
+ floor(self.first_operand * self.second_operand), LongType()
+ )
+ self.result_snowpark_pandas_type = TimedeltaType()
+ elif (
+ isinstance(self.second_datatype(), _IntegralType)
+ and isinstance(self.first_datatype(), StringType)
+ ) or (
+ isinstance(self.second_datatype(), StringType)
+ and isinstance(self.first_datatype(), _IntegralType)
+ ):
+ # string/integer case (only for mul/rmul).
+ # swap first_operand with second_operand because
+            # REPEAT(base, n) expects the base argument to be a string
+ if isinstance(self.first_datatype(), _IntegralType):
+ self.first_operand, self.second_operand = (
+ self.second_operand,
+ self.first_operand,
+ )
+
+ self.result_column = iff(
+ self.second_operand > pandas_lit(0),
+ repeat(self.first_operand, self.second_operand),
+ # Snowflake's repeat doesn't support negative number,
+ # but pandas will return an empty string
+ pandas_lit(""),
+ )
+
+
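# Plain pandas counterpart (not part of the diff) of the string/integer branch
# above: multiplying a string Series by an integer repeats the string, and
# pandas yields an empty string for a non-positive count, which is what the
# iff(count > 0, REPEAT(...), '') guard mirrors.
import pandas as pd

s = pd.Series(["ab"])
assert (s * 3).iloc[0] == "ababab"
assert (s * 0).iloc[0] == ""
assert (s * -2).iloc[0] == ""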
+class EqOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ # For `eq` and `ne`, note that Snowflake will consider 1 equal to
+ # Timedelta(1) because those two have the same representation in Snowflake,
+ # so we have to compare types in the client.
+ if _op_is_between_timedelta_and_numeric(
+ self.first_datatype, self.second_datatype
+ ):
+ self.result_column = pandas_lit(False)
+
+
+class NeOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ # For `eq` and `ne`, note that Snowflake will consider 1 equal to
+ # Timedelta(1) because those two have the same representation in Snowflake,
+ # so we have to compare types in the client.
+ if _op_is_between_timedelta_and_numeric(
+ self.first_datatype, self.second_datatype
+ ):
+ self.result_column = pandas_lit(True)
+
+
+class FloordivOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ self.result_column = floor(self.first_operand / self.second_operand)
+ if isinstance(
+ self.first_datatype(), TimedeltaType
+ ) and _is_numeric_non_timedelta_type(self.second_datatype()):
+ self.result_column = cast(self.result_column, LongType())
+ self.result_snowpark_pandas_type = TimedeltaType()
+
+
+class TruedivOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ if isinstance(
+ self.first_datatype(), TimedeltaType
+ ) and _is_numeric_non_timedelta_type(self.second_datatype()):
+ self.result_column = cast(
+ floor(self.first_operand / self.second_operand), LongType()
+ )
+ self.result_snowpark_pandas_type = TimedeltaType()
+
+
+class PowOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ self.result_column = compute_power_between_snowpark_columns(
+ self.first_operand, self.second_operand
)
- elif op == "equal_null":
+
+
+class OrOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ self.result_column = self.first_operand | self.second_operand
+
+
+class AndOp(BinaryOp):
+ def _custom_compute(self) -> None:
+ self.result_column = self.first_operand & self.second_operand
+
+
+class EqualNullOp(BinaryOp):
+ def _custom_compute(self) -> None:
# TODO(SNOW-1641716): In Snowpark pandas, generally use this equal_null
        # with type checking instead of snowflake.snowpark.functions.equal_null.
- if not are_equal_types(first_datatype(), second_datatype()):
- binary_op_result_column = pandas_lit(False)
+ if not are_equal_types(self.first_datatype(), self.second_datatype()):
+ self.result_column = pandas_lit(False)
else:
- binary_op_result_column = first_operand.equal_null(second_operand)
- elif (
- op == "sub"
- and isinstance(first_datatype(), TimestampType)
- and isinstance(second_datatype(), NullType)
- ):
- # Timestamp - NULL or NULL - Timestamp raises SQL compilation error,
- # but it's valid in pandas and returns NULL.
- binary_op_result_column = pandas_lit(None)
- elif (
- op == "sub"
- and isinstance(first_datatype(), NullType)
- and isinstance(second_datatype(), TimestampType)
- ):
- # Timestamp - NULL or NULL - Timestamp raises SQL compilation error,
- # but it's valid in pandas and returns NULL.
- binary_op_result_column = pandas_lit(None)
- elif (
- op == "sub"
- and isinstance(first_datatype(), TimestampType)
- and isinstance(second_datatype(), TimestampType)
- ):
- return _compute_subtraction_between_snowpark_timestamp_columns(
- first_operand=first_operand,
- first_datatype=first_datatype(),
- second_operand=second_operand,
- second_datatype=second_datatype(),
- )
- # If there is no special binary_op_result_column result, it means the operator and
- # the data type of the column don't need special handling. Then we get the overloaded
- # operator from Snowpark Column class, e.g., __add__ to perform binary operations.
- if binary_op_result_column is None:
- binary_op_result_column = getattr(first_operand, f"__{op}__")(second_operand)
-
- return SnowparkPandasColumn(
- snowpark_column=binary_op_result_column,
- snowpark_pandas_type=snowpark_pandas_type,
- )
+ self.result_column = self.first_operand.equal_null(self.second_operand)
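# Minimal standalone sketch (hypothetical names, not part of the diff) of the
# subclass-dispatch idea BinaryOp.create uses above: the op string "floordiv" is
# camel-cased to "FloordivOp" and matched against __subclasses__(), falling back
# to the base class when no specialized subclass exists.
class Base:
    @classmethod
    def create(cls, op: str) -> "Base":
        name = "".join(part.title() for part in op.split("_")) + "Op"
        match = next((sub for sub in cls.__subclasses__() if sub.__name__ == name), cls)
        return match()

class FloordivOp(Base):
    pass

assert type(Base.create("floordiv")) is FloordivOp  # specialized subclass found
assert type(Base.create("unknown")) is Base         # generic fallback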
def are_equal_types(type1: DataType, type2: DataType) -> bool:
@@ -644,104 +844,6 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool:
return type1 == type2
-def compute_binary_op_between_snowpark_column_and_scalar(
- op: str,
- first_operand: SnowparkColumn,
- datatype: DataTypeGetter,
- second_operand: Scalar,
-) -> SnowparkPandasColumn:
- """
- Compute the binary operation between a Snowpark column and a scalar.
- Args:
- op: the name of binary operation
- first_operand: The SnowparkColumn for lhs
- datatype: Callable for Snowpark data type
- second_operand: Scalar value
-
- Returns:
- SnowparkPandasColumn for translated pandas operation
- """
-
- def second_datatype() -> DataType:
- return infer_object_type(second_operand)
-
- return compute_binary_op_between_snowpark_columns(
- op, first_operand, datatype, pandas_lit(second_operand), second_datatype
- )
-
-
-def compute_binary_op_between_scalar_and_snowpark_column(
- op: str,
- first_operand: Scalar,
- second_operand: SnowparkColumn,
- datatype: DataTypeGetter,
-) -> SnowparkPandasColumn:
- """
- Compute the binary operation between a scalar and a Snowpark column.
- Args:
- op: the name of binary operation
- first_operand: Scalar value
- second_operand: The SnowparkColumn for rhs
- datatype: Callable for Snowpark data type
- it is not needed.
-
- Returns:
- SnowparkPandasColumn for translated pandas operation
- """
-
- def first_datatype() -> DataType:
- return infer_object_type(first_operand)
-
- return compute_binary_op_between_snowpark_columns(
- op, pandas_lit(first_operand), first_datatype, second_operand, datatype
- )
-
-
-def compute_binary_op_with_fill_value(
- op: str,
- lhs: SnowparkColumn,
- lhs_datatype: DataTypeGetter,
- rhs: SnowparkColumn,
- rhs_datatype: DataTypeGetter,
- fill_value: Scalar,
-) -> SnowparkPandasColumn:
- """
- Helper method for performing binary operations.
- 1. Fills NaN/None values in the lhs and rhs with the given fill_value.
- 2. Computes the binary operation expression for lhs rhs.
-
- fill_value replaces NaN/None values when only either lhs or rhs is NaN/None, not both lhs and rhs.
- For instance, with fill_value = 100,
- 1. Given lhs = None and rhs = 10, lhs is replaced with fill_value.
- result = lhs + rhs => None + 10 => 100 (replaced) + 10 = 110
- 2. Given lhs = 3 and rhs = None, rhs is replaced with fill_value.
- result = lhs + rhs => 3 + None => 3 + 100 (replaced) = 103
- 3. Given lhs = None and rhs = None, neither lhs nor rhs is replaced since they both are None.
- result = lhs + rhs => None + None => None.
-
- Args:
- op: pandas operation to perform between lhs and rhs
- lhs: the lhs SnowparkColumn
- lhs_datatype: Callable for Snowpark Datatype for lhs
- rhs: the rhs SnowparkColumn
- rhs_datatype: Callable for Snowpark Datatype for rhs
- fill_value: Fill existing missing (NaN) values, and any new element needed for
- successful DataFrame alignment, with this value before computation.
-
- Returns:
- SnowparkPandasColumn for translated pandas operation
- """
- lhs_cond, rhs_cond = lhs, rhs
- if fill_value is not None:
- fill_value_lit = pandas_lit(fill_value)
- lhs_cond = iff(lhs.is_null() & ~rhs.is_null(), fill_value_lit, lhs)
- rhs_cond = iff(rhs.is_null() & ~lhs.is_null(), fill_value_lit, rhs)
-
- return compute_binary_op_between_snowpark_columns(
- op, lhs_cond, lhs_datatype, rhs_cond, rhs_datatype
- )
-
-
def merge_label_and_identifier_pairs(
sorted_column_labels: list[str],
q_frame_sorted: list[tuple[str, str]],
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py
index 882dc79d2a8..4eaf98d9b29 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py
@@ -189,8 +189,6 @@ def compute_bin_indices(
values_frame,
cuts_frame,
how="asof",
- left_on=[],
- right_on=[],
left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0],
right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0],
match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py
index c2c224e404c..6207bd2399a 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py
@@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame(
key,
count_frame,
"cross",
- left_on=[],
- right_on=[],
inherit_join_index=InheritJoinIndex.FROM_LEFT,
)
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
index 457bd388f2b..d07211dbcf5 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
@@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple):
result_column_mapper: JoinOrAlignResultColumnMapper
+def assert_snowpark_pandas_types_match(
+ left: InternalFrame,
+ right: InternalFrame,
+ left_join_identifiers: list[str],
+ right_join_identifiers: list[str],
+) -> None:
+ """
+ If Snowpark pandas types do not match for the given identifiers, then a ValueError will be raised.
+
+ Args:
+ left: An internal frame to use on left side of join.
+ right: An internal frame to use on right side of join.
+        left_join_identifiers: List of Snowflake identifiers from the 'left' frame whose types are checked.
+        right_join_identifiers: List of Snowflake identifiers from the 'right' frame whose types are checked.
+            left_join_identifiers and right_join_identifiers must be lists of equal length.
+
+ Returns: None
+
+ Raises: ValueError
+ """
+ left_types = [
+ left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
+ for id in left_join_identifiers
+ ]
+ right_types = [
+ right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
+ for id in right_join_identifiers
+ ]
+ for i, (lt, rt) in enumerate(zip(left_types, right_types)):
+ if lt != rt:
+ left_on_id = left_join_identifiers[i]
+ idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
+ key = left.data_column_pandas_labels[idx]
+ lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
+ rt = (
+ rt
+ if rt is not None
+ else right.get_snowflake_type(right_join_identifiers[i])
+ )
+ raise ValueError(
+ f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
+ f"If you wish to proceed you should use pd.concat"
+ )
+
+
def join(
left: InternalFrame,
right: InternalFrame,
how: JoinTypeLit,
- left_on: list[str],
- right_on: list[str],
+ left_on: Optional[list[str]] = None,
+ right_on: Optional[list[str]] = None,
left_match_col: Optional[str] = None,
right_match_col: Optional[str] = None,
match_comparator: Optional[MatchComparator] = None,
@@ -161,40 +206,48 @@ def join(
include mapping for index + data columns, ordering columns and row position column
if exists.
"""
- assert len(left_on) == len(
- right_on
- ), "left_on and right_on must be of same length or both be None"
- if join_key_coalesce_config is not None:
- assert len(join_key_coalesce_config) == len(
- left_on
- ), "join_key_coalesce_config must be of same length as left_on and right_on"
assert how in get_args(
JoinTypeLit
), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}"
- def assert_snowpark_pandas_types_match() -> None:
- """If Snowpark pandas types do not match, then a ValueError will be raised."""
- left_types = [
- left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
- for id in left_on
- ]
- right_types = [
- right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
- for id in right_on
- ]
- for i, (lt, rt) in enumerate(zip(left_types, right_types)):
- if lt != rt:
- left_on_id = left_on[i]
- idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
- key = left.data_column_pandas_labels[idx]
- lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
- rt = rt if rt is not None else right.get_snowflake_type(right_on[i])
- raise ValueError(
- f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
- f"If you wish to proceed you should use pd.concat"
- )
+ left_on = left_on or []
+ right_on = right_on or []
+ assert len(left_on) == len(
+ right_on
+ ), "left_on and right_on must be of same length or both be None"
- assert_snowpark_pandas_types_match()
+ if how == "asof":
+ assert (
+ left_match_col
+ ), "ASOF join was not provided a column identifier to match on for the left table"
+ assert (
+ right_match_col
+ ), "ASOF join was not provided a column identifier to match on for the right table"
+ assert (
+ match_comparator
+ ), "ASOF join was not provided a comparator for the match condition"
+ left_join_key = [left_match_col]
+ right_join_key = [right_match_col]
+ left_join_key.extend(left_on)
+ right_join_key.extend(right_on)
+ if join_key_coalesce_config is not None:
+ assert len(join_key_coalesce_config) == len(
+ left_join_key
+ ), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key"
+ else:
+ left_join_key = left_on
+ right_join_key = right_on
+ assert (
+ left_match_col is None
+ and right_match_col is None
+ and match_comparator is None
+ ), f"match condition should not be provided for {how} join"
+ if join_key_coalesce_config is not None:
+ assert len(join_key_coalesce_config) == len(
+ left_join_key
+ ), "join_key_coalesce_config must be of same length as left_on and right_on"
+
+ assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key)
# Re-project the active columns to make sure all active columns of the internal frame participate
# in the join operation, and unnecessary columns are dropped from the projected columns.
@@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None:
match_comparator=match_comparator,
how=how,
)
-
return _create_internal_frame_with_join_or_align_result(
joined_ordered_dataframe,
left,
right,
how,
- left_on,
- right_on,
+ left_join_key,
+ right_join_key,
sort,
join_key_coalesce_config,
inherit_join_index,
@@ -1075,7 +1127,7 @@ def join_on_index_columns(
Returns:
An InternalFrame for the joined result.
- A JoinOrAlignResultColumnMapper that provides quited identifiers mapping from the
+ A JoinOrAlignResultColumnMapper that provides quoted identifiers mapping from the
original left and right dataframe to the joined dataframe, it is guaranteed to
include mapping for index + data columns, ordering columns and row position column
if exists.
@@ -1263,7 +1315,7 @@ def align_on_index(
* outer: use union of index from both frames, sort index lexicographically.
Returns:
An InternalFrame for the aligned result.
- A JoinOrAlignResultColumnMapper that provides quited identifiers mapping from the
+ A JoinOrAlignResultColumnMapper that provides quoted identifiers mapping from the
original left and right dataframe to the aligned dataframe, it is guaranteed to
include mapping for index + data columns, ordering columns and row position column
if exists.
@@ -1402,6 +1454,9 @@ def _sort_on_join_keys(self) -> None:
)
elif self._how == "right":
ordering_column_identifiers = mapped_right_on
+ elif self._how == "asof":
+ # Order only by the left match_condition column
+ ordering_column_identifiers = [mapped_left_on[0]]
else: # left join, inner join, left align, coalesce align
ordering_column_identifiers = mapped_left_on
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py
index f7ae87c2a5d..91537d98e30 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py
@@ -1197,22 +1197,29 @@ def join(
# get the new mapped right on identifier
right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols]
- # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
- on = None
- for left_col, right_col in zip(left_on_cols, right_on_cols):
- eq = Column(left_col).equal_null(Column(right_col))
- on = eq if on is None else on & eq
-
if how == "asof":
- assert left_match_col, "left_match_col was not provided to ASOF Join"
+ assert (
+ left_match_col
+ ), "ASOF join was not provided a column identifier to match on for the left table"
left_match_col = Column(left_match_col)
# Get the new mapped right match condition identifier
- assert right_match_col, "right_match_col was not provided to ASOF Join"
+ assert (
+ right_match_col
+ ), "ASOF join was not provided a column identifier to match on for the right table"
right_match_col = Column(right_identifiers_rename_map[right_match_col])
# ASOF Join requires the use of match_condition
- assert match_comparator, "match_comparator was not provided to ASOF Join"
+ assert (
+ match_comparator
+ ), "ASOF join was not provided a comparator for the match condition"
+
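+            # Generate the SQL ON clause 'col1 = col2 AND col3 = col4 ...'. Note that the ASOF join
+            # keys use plain equality here rather than EQUAL_NULL.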
+ on = None
+ for left_col, right_col in zip(left_on_cols, right_on_cols):
+ eq = Column(left_col).__eq__(Column(right_col))
+ on = eq if on is None else on & eq
+
snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join(
right=right_snowpark_dataframe_ref.snowpark_dataframe,
+ on=on,
how=how,
match_condition=getattr(left_match_col, match_comparator.value)(
right_match_col
@@ -1224,6 +1231,12 @@ def join(
right_snowpark_dataframe_ref.snowpark_dataframe, how=how
)
else:
+ # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
+ on = None
+ for left_col, right_col in zip(left_on_cols, right_on_cols):
+ eq = Column(left_col).equal_null(Column(right_col))
+ on = eq if on is None else on & eq
+
snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join(
right_snowpark_dataframe_ref.snowpark_dataframe, on, how
)
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py
index 3bf1062107e..e7a96b49ef1 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py
@@ -520,12 +520,15 @@ def single_pivot_helper(
data_column_snowflake_quoted_identifiers: new data column snowflake quoted identifiers this pivot result
data_column_pandas_labels: new data column pandas labels for this pivot result
"""
- snowpark_aggr_func = get_snowflake_agg_func(pandas_aggr_func_name, {})
- if not is_supported_snowflake_pivot_agg_func(snowpark_aggr_func):
+ snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0)
+ if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func(
+ snowflake_agg_func.snowpark_aggregation
+ ):
# TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations
raise ErrorMessage.not_implemented(
f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments."
)
+ snowpark_aggr_func = snowflake_agg_func.snowpark_aggregation
pandas_aggr_label, aggr_snowflake_quoted_identifier = value_label_to_identifier_pair
@@ -1231,17 +1234,19 @@ def get_margin_aggregation(
Returns:
Snowpark column expression for the aggregation function result.
"""
- resolved_aggfunc = get_snowflake_agg_func(aggfunc, {})
+ resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}, axis=0)
# This would have been resolved during the original pivot at an early stage.
assert resolved_aggfunc is not None, "resolved_aggfunc is None"
- aggfunc_expr = resolved_aggfunc(snowflake_quoted_identifier)
+ aggregation_expression = resolved_aggfunc.snowpark_aggregation(
+ snowflake_quoted_identifier
+ )
- if resolved_aggfunc == sum_:
- aggfunc_expr = coalesce(aggfunc_expr, pandas_lit(0))
+ if resolved_aggfunc.snowpark_aggregation == sum_:
+ aggregation_expression = coalesce(aggregation_expression, pandas_lit(0))
- return aggfunc_expr
+ return aggregation_expression
def expand_pivot_result_with_pivot_table_margins_no_groupby_columns(
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
index de83e0429bf..ba8ceedec5e 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
@@ -649,8 +649,6 @@ def perform_asof_join_on_frame(
left=preserving_frame,
right=referenced_frame,
how="asof",
- left_on=[],
- right_on=[],
left_match_col=left_timecol_snowflake_quoted_identifier,
right_match_col=right_timecol_snowflake_quoted_identifier,
match_comparator=(
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
index d38584c14de..e19a6de37ba 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
@@ -567,9 +567,7 @@ def __new__(
attrs (Dict[str, Any]): The attributes of the class.
Returns:
- Union[snowflake.snowpark.modin.pandas.series.Series,
- snowflake.snowpark.modin.pandas.dataframe.DataFrame,
- snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy,
+ Union[snowflake.snowpark.modin.pandas.groupby.DataFrameGroupBy,
snowflake.snowpark.modin.pandas.resample.Resampler,
snowflake.snowpark.modin.pandas.window.Window,
snowflake.snowpark.modin.pandas.window.Rolling]:
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py
index 0242177d1f0..3b714087535 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py
@@ -22,9 +22,17 @@
cast,
convert_timezone,
date_part,
+ dayofmonth,
+ hour,
iff,
+ minute,
+ month,
+ second,
+ timestamp_tz_from_parts,
to_decimal,
+ to_timestamp_ntz,
trunc,
+ year,
)
from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit
from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
@@ -467,3 +475,60 @@ def convert_dateoffset_to_interval(
)
interval_kwargs[new_param] = offset
return Interval(**interval_kwargs)
+
+
+def tz_localize_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column:
+ """
+    Localize a tz-naive datetime column to tz-aware.
+
+    Args:
+        column: the Snowpark datetime column
+        tz: str, pytz.timezone, or dt.tzinfo. The time zone to localize the timestamps to.
+            A tz of None will convert to UTC and remove the timezone information.
+
+ Returns:
+ The column after tz localization
+ """
+ if tz is None:
+ # If this column is already a TIMESTAMP_NTZ, this cast does nothing.
+ # If the column is a TIMESTAMP_TZ, the cast drops the timezone and converts
+ # to TIMESTAMP_NTZ.
+ return to_timestamp_ntz(column)
+ else:
+ if isinstance(tz, dt.tzinfo):
+ tz_name = tz.tzname(None)
+ else:
+ tz_name = tz
+ return timestamp_tz_from_parts(
+ year(column),
+ month(column),
+ dayofmonth(column),
+ hour(column),
+ minute(column),
+ second(column),
+ date_part("nanosecond", column),
+ pandas_lit(tz_name),
+ )
+
+
+def tz_convert_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column:
+ """
+ Converts a datetime column to the specified timezone
+
+ Args:
+ column: the Snowpark datetime column
+ tz: the target timezone
+
+ Returns:
+ The column after conversion to the specified timezone
+ """
+ if tz is None:
+ return to_timestamp_ntz(convert_timezone(pandas_lit("UTC"), column))
+ else:
+ if isinstance(tz, dt.tzinfo):
+ tz_name = tz.tzname(None)
+ else:
+ tz_name = tz
+ return convert_timezone(pandas_lit(tz_name), column)
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py
index 5656bbfb14a..34a3376fcc1 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py
@@ -41,6 +41,7 @@
mean,
min as min_,
sum as sum_,
+ to_char,
to_timestamp_ntz,
to_timestamp_tz,
typeof,
@@ -75,6 +76,8 @@
StringType,
StructField,
StructType,
+ TimestampTimeZone,
+ TimestampType,
VariantType,
_FractionalType,
)
@@ -1273,7 +1276,7 @@ def check_snowpark_pandas_object_in_arg(arg: Any) -> bool:
if check_snowpark_pandas_object_in_arg(v):
return True
else:
- from snowflake.snowpark.modin.pandas import DataFrame, Series
+ from modin.pandas import DataFrame, Series
return isinstance(arg, (DataFrame, Series))
@@ -1289,14 +1292,23 @@ def snowpark_to_pandas_helper(
) -> Union[native_pd.Index, native_pd.DataFrame]:
"""
The helper function retrieves a pandas dataframe from an OrderedDataFrame. Performs necessary type
- conversions for variant types on the client. This function issues 2 queries, one metadata query
- to retrieve the schema and one query to retrieve the data values.
+    conversions, including:
+    1. For VARIANT types, OrderedDataFrame.to_pandas may convert datetime-like types to string. So we add one `typeof`
+    column for each variant column and use that metadata to convert datetime-like values back to their original types.
+    2. For the TIMESTAMP_TZ type, OrderedDataFrame.to_pandas converts values into the local session timezone and loses
+    the original timezone. So we cast TIMESTAMP_TZ columns to string first and convert them back after to_pandas to
+    preserve the original timezone offset. Note that the actual timezone is lost in the Snowflake backend; only the
+    offset is preserved.
+    3. For Timedelta columns, since we currently represent the values as integers, we need to explicitly cast them
+    back to Timedelta.
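+    For example, a TIMESTAMP_TZ value of 2024-01-01 00:00:00 +0900 (hypothetical) is fetched as the string
+    '2024-01-01 00:00:00.000000000 +0900' (format 'YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM') and converted back with
+    pandas.to_datetime, keeping the original +09:00 offset.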
Args:
frame: The internal frame to convert to pandas Dataframe (or Index if index_only is true)
index_only: if true, only turn the index columns into a pandas Index
- statement_params: Dictionary of statement level parameters to be passed to conversion function of ordered dataframe abstraction.
- kwargs: Additional keyword-only args to pass to internal `to_pandas` conversion for orderded dataframe abstraction.
+ statement_params: Dictionary of statement level parameters to be passed to conversion function of ordered
+ dataframe abstraction.
+ kwargs: Additional keyword-only args to pass to internal `to_pandas` conversion for ordered dataframe
+ abstraction.
Returns:
pandas dataframe
@@ -1365,7 +1377,7 @@ def snowpark_to_pandas_helper(
)
variant_type_identifiers = list(map(lambda t: t[0], variant_type_columns_info))
- # Step 3: Create for each variant type column a separate type column (append at end), and retrieve data values
+ # Step 3.1: Create for each variant type column a separate type column (append at end), and retrieve data values
# (and types for variant type columns).
variant_type_typeof_identifiers = (
ordered_dataframe.generate_snowflake_quoted_identifiers(
@@ -1384,10 +1396,36 @@ def snowpark_to_pandas_helper(
[typeof(col(id)) for id in variant_type_identifiers],
)
+    # Step 3.2: cast TIMESTAMP_TZ columns to string to preserve their original timezone offsets
+ timestamp_tz_identifiers = [
+ info[0]
+ for info in columns_info
+ if info[1] == TimestampType(TimestampTimeZone.TZ)
+ ]
+ timestamp_tz_str_identifiers = (
+ ordered_dataframe.generate_snowflake_quoted_identifiers(
+ pandas_labels=[
+ f"{unquote_name_if_quoted(id)}_str" for id in timestamp_tz_identifiers
+ ],
+ excluded=column_identifiers,
+ )
+ )
+ if len(timestamp_tz_identifiers):
+ ordered_dataframe = append_columns(
+ ordered_dataframe,
+ timestamp_tz_str_identifiers,
+ [
+ to_char(col(id), format="YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM")
+ for id in timestamp_tz_identifiers
+ ],
+ )
+
# ensure that snowpark_df has unique identifiers, so the native pandas DataFrame object created here
# also does have unique column names which is a prerequisite for the post-processing logic following.
assert is_duplicate_free(
- column_identifiers + variant_type_typeof_identifiers
+ column_identifiers
+ + variant_type_typeof_identifiers
+ + timestamp_tz_str_identifiers
), "Snowpark DataFrame to convert must have unique column identifiers"
pandas_df = ordered_dataframe.to_pandas(statement_params=statement_params, **kwargs)
@@ -1400,7 +1438,9 @@ def snowpark_to_pandas_helper(
# Step 3a: post-process variant type columns, if any exist.
id_to_label_mapping = dict(
zip(
- column_identifiers + variant_type_typeof_identifiers,
+ column_identifiers
+ + variant_type_typeof_identifiers
+ + timestamp_tz_str_identifiers,
pandas_df.columns,
)
)
@@ -1439,6 +1479,25 @@ def convert_variant_type_to_pandas(row: native_pd.Series) -> Any:
id_to_label_mapping[quoted_name]
].apply(lambda value: None if value is None else json.loads(value))
+    # Convert TIMESTAMP_TZ values stored as strings back to datetime64tz.
+ if any(
+ dtype == TimestampType(TimestampTimeZone.TZ) for (_, dtype) in columns_info
+ ):
+ id_to_label_mapping = dict(
+ zip(
+ column_identifiers
+ + variant_type_typeof_identifiers
+ + timestamp_tz_str_identifiers,
+ pandas_df.columns,
+ )
+ )
+ for ts_id, ts_str_id in zip(
+ timestamp_tz_identifiers, timestamp_tz_str_identifiers
+ ):
+ pandas_df[id_to_label_mapping[ts_id]] = native_pd.to_datetime(
+ pandas_df[id_to_label_mapping[ts_str_id]]
+ )
+
# Step 5. Return the original amount of columns by stripping any typeof(...) columns appended if
# schema contained VariantType.
downcast_pandas_df = pandas_df[pandas_df.columns[: len(columns_info)]]
@@ -1460,9 +1519,15 @@ def convert_str_to_timedelta(x: str) -> pd.Timedelta:
downcast_pandas_df.columns, cached_snowpark_pandas_types
):
if snowpark_pandas_type is not None and snowpark_pandas_type == timedelta_t:
- downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply(
- convert_str_to_timedelta
- )
+ # By default, pandas warns, "A value is trying to be set on a
+ # copy of a slice from a DataFrame" here because we are
+ # assigning a column to downcast_pandas_df, which is a copy of
+ # a slice of pandas_df. We don't care what happens to pandas_df,
+ # so the warning isn't useful to us.
+ with native_pd.option_context("mode.chained_assignment", None):
+ downcast_pandas_df[pandas_label] = pandas_df[pandas_label].apply(
+ convert_str_to_timedelta
+ )
# Step 7. postprocessing for return types
if index_only:
@@ -1493,7 +1558,11 @@ def convert_str_to_timedelta(x: str) -> pd.Timedelta:
# multiple timezones. So here we cast the index to the index_type when ret = pd.Index(...) above cannot
# figure out a non-object dtype. Note that the index_type is a logical type may not be 100% accurate.
if is_object_dtype(ret.dtype) and not is_object_dtype(index_type):
- ret = ret.astype(index_type)
+ # TODO: SNOW-1657460 fix index_type for timestamp_tz
+ try:
+ ret = ret.astype(index_type)
+ except ValueError: # e.g., Tz-aware datetime.datetime cannot be converted to datetime64
+ pass
return ret
# to_pandas() does not preserve the index information and will just return a
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index 400e98562f9..e971b15b6d6 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -149,8 +149,6 @@
)
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
AGG_NAME_COL_LABEL,
- GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE,
- GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES,
AggFuncInfo,
AggFuncWithLabel,
AggregateColumnOpParameters,
@@ -161,7 +159,6 @@
convert_agg_func_arg_to_col_agg_func_map,
drop_non_numeric_data_columns,
generate_column_agg_info,
- generate_rowwise_aggregation_function,
get_agg_func_to_col_map,
get_pandas_aggr_func_name,
get_snowflake_agg_func,
@@ -172,6 +169,7 @@
APPLY_LABEL_COLUMN_QUOTED_IDENTIFIER,
APPLY_VALUE_COLUMN_QUOTED_IDENTIFIER,
DEFAULT_UDTF_PARTITION_SIZE,
+ GroupbyApplySortMethod,
check_return_variant_and_get_return_type,
create_udf_for_series_apply,
create_udtf_for_apply_axis_1,
@@ -184,11 +182,7 @@
sort_apply_udtf_result_columns_by_pandas_positions,
)
from snowflake.snowpark.modin.plugin._internal.binary_op_utils import (
- compute_binary_op_between_scalar_and_snowpark_column,
- compute_binary_op_between_snowpark_column_and_scalar,
- compute_binary_op_between_snowpark_columns,
- compute_binary_op_with_fill_value,
- is_binary_op_supported,
+ BinaryOp,
merge_label_and_identifier_pairs,
prepare_binop_pairs_between_dataframe_and_dataframe,
)
@@ -282,6 +276,8 @@
raise_if_to_datetime_not_supported,
timedelta_freq_to_nanos,
to_snowflake_timestamp_format,
+ tz_convert_column,
+ tz_localize_column,
)
from snowflake.snowpark.modin.plugin._internal.transpose_utils import (
clean_up_transpose_result_index_and_labels,
@@ -1854,7 +1850,7 @@ def _binary_op_scalar_rhs(
replace_mapping = {}
data_column_snowpark_pandas_types = []
for identifier in self._modin_frame.data_column_snowflake_quoted_identifiers:
- expression, snowpark_pandas_type = compute_binary_op_with_fill_value(
+ expression, snowpark_pandas_type = BinaryOp.create_with_fill_value(
op=op,
lhs=col(identifier),
lhs_datatype=lambda identifier=identifier: self._modin_frame.get_snowflake_type(
@@ -1863,7 +1859,7 @@ def _binary_op_scalar_rhs(
rhs=pandas_lit(other),
rhs_datatype=lambda: infer_object_type(other),
fill_value=fill_value,
- )
+ ).compute()
replace_mapping[identifier] = expression
data_column_snowpark_pandas_types.append(snowpark_pandas_type)
return SnowflakeQueryCompiler(
@@ -1914,7 +1910,7 @@ def _binary_op_list_like_rhs_axis_0(
replace_mapping = {}
snowpark_pandas_types = []
for identifier in new_frame.data_column_snowflake_quoted_identifiers[:-1]:
- expression, snowpark_pandas_type = compute_binary_op_with_fill_value(
+ expression, snowpark_pandas_type = BinaryOp.create_with_fill_value(
op=op,
lhs=col(identifier),
lhs_datatype=lambda identifier=identifier: new_frame.get_snowflake_type(
@@ -1923,7 +1919,7 @@ def _binary_op_list_like_rhs_axis_0(
rhs=col(other_identifier),
rhs_datatype=lambda: new_frame.get_snowflake_type(other_identifier),
fill_value=fill_value,
- )
+ ).compute()
replace_mapping[identifier] = expression
snowpark_pandas_types.append(snowpark_pandas_type)
@@ -1986,7 +1982,7 @@ def _binary_op_list_like_rhs_axis_1(
# rhs is not guaranteed to be a scalar value - it can be a list-like as well.
# Convert all list-like objects to a list.
rhs_lit = pandas_lit(rhs) if is_scalar(rhs) else pandas_lit(rhs.tolist())
- expression, snowpark_pandas_type = compute_binary_op_with_fill_value(
+ expression, snowpark_pandas_type = BinaryOp.create_with_fill_value(
op,
lhs=lhs,
lhs_datatype=lambda identifier=identifier: self._modin_frame.get_snowflake_type(
@@ -1995,7 +1991,7 @@ def _binary_op_list_like_rhs_axis_1(
rhs=rhs_lit,
rhs_datatype=lambda rhs=rhs: infer_object_type(rhs),
fill_value=fill_value,
- )
+ ).compute()
replace_mapping[identifier] = expression
snowpark_pandas_types.append(snowpark_pandas_type)
@@ -2041,8 +2037,8 @@ def binary_op(
# Native pandas does not support binary operations between a Series and a list-like object.
from modin.pandas import Series
+ from modin.pandas.dataframe import DataFrame
- from snowflake.snowpark.modin.pandas.dataframe import DataFrame
from snowflake.snowpark.modin.pandas.utils import is_scalar
# fail explicitly for unsupported scenarios
@@ -2056,7 +2052,7 @@ def binary_op(
# match pandas documentation; hence it is omitted in the Snowpark pandas implementation.
raise ValueError("Only scalars can be used as fill_value.")
- if not is_binary_op_supported(op):
+ if not BinaryOp.is_binary_op_supported(op):
ErrorMessage.not_implemented(
f"Snowpark pandas doesn't yet support '{op}' binary operation"
)
@@ -2121,7 +2117,7 @@ def binary_op(
)[0]
# add new column with result as unnamed
- new_column_expr, snowpark_pandas_type = compute_binary_op_with_fill_value(
+ new_column_expr, snowpark_pandas_type = BinaryOp.create_with_fill_value(
op=op,
lhs=col(lhs_quoted_identifier),
lhs_datatype=lambda: aligned_frame.get_snowflake_type(
@@ -2132,7 +2128,7 @@ def binary_op(
rhs_quoted_identifier
),
fill_value=fill_value,
- )
+ ).compute()
# name is dropped when names of series differ. A dropped name is using unnamed series label.
new_column_name = (
@@ -3557,42 +3553,22 @@ def convert_func_to_agg_func_info(
agg_col_ops, new_data_column_index_names = generate_column_agg_info(
internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby
)
- # Get the column aggregation functions used to check if the function
- # preserves Snowpark pandas types.
- agg_col_funcs = []
- for _, func in column_to_agg_func.items():
- if is_list_like(func) and not is_named_tuple(func):
- for fn in func:
- agg_col_funcs.append(fn.func)
- else:
- agg_col_funcs.append(func.func)
# the pandas label and quoted identifier generated for each result column
# after aggregation will be used as new pandas label and quoted identifiers.
new_data_column_pandas_labels = []
new_data_column_quoted_identifiers = []
new_data_column_snowpark_pandas_types = []
- for i in range(len(agg_col_ops)):
- col_agg_op = agg_col_ops[i]
- col_agg_func = agg_col_funcs[i]
- new_data_column_pandas_labels.append(col_agg_op.agg_pandas_label)
+ for agg_col_op in agg_col_ops:
+ new_data_column_pandas_labels.append(agg_col_op.agg_pandas_label)
new_data_column_quoted_identifiers.append(
- col_agg_op.agg_snowflake_quoted_identifier
+ agg_col_op.agg_snowflake_quoted_identifier
+ )
+ new_data_column_snowpark_pandas_types.append(
+ agg_col_op.data_type
+ if isinstance(agg_col_op.data_type, SnowparkPandasType)
+ and agg_col_op.snowflake_agg_func.preserves_snowpark_pandas_types
+ else None
)
- if col_agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE:
- new_data_column_snowpark_pandas_types.append(
- col_agg_op.data_type
- if isinstance(col_agg_op.data_type, SnowparkPandasType)
- else None
- )
- elif col_agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES:
- # In the case where the aggregation overrides the type of the output data column
- # (e.g. any always returns boolean data columns), set the output Snowpark pandas type
- # of the given column to None
- new_data_column_snowpark_pandas_types.append(None) # type: ignore
- else:
- self._raise_not_implemented_error_for_timedelta()
- new_data_column_snowpark_pandas_types = None # type: ignore
-
# The ordering of the named aggregations is changed by us when we process
# the agg_kwargs into the func dict (named aggregations on the same
# column are moved to be contiguous, see groupby.py::aggregate for an
@@ -3645,7 +3621,7 @@ def convert_func_to_agg_func_info(
),
agg_pandas_label=None,
agg_snowflake_quoted_identifier=row_position_quoted_identifier,
- snowflake_agg_func=min_,
+ snowflake_agg_func=get_snowflake_agg_func("min", agg_kwargs={}, axis=0),
ordering_columns=internal_frame.ordering_columns,
)
agg_col_ops.append(row_position_agg_column_op)
@@ -3757,6 +3733,8 @@ def groupby_apply(
agg_args: Any,
agg_kwargs: dict[str, Any],
series_groupby: bool,
+ force_single_group: bool = False,
+ force_list_like_to_series: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Group according to `by` and `level`, apply a function to each group, and combine the results.
@@ -3777,6 +3755,10 @@ def groupby_apply(
Keyword arguments to pass to agg_func when applying it to each group.
series_groupby:
Whether we are performing a SeriesGroupBy.apply() instead of a DataFrameGroupBy.apply()
+            force_single_group:
+                Force a single group (an empty set of group-by labels); useful for DataFrame.apply() with axis=0.
+            force_list_like_to_series:
+                Force the function result to a Series if it is list-like.
Returns
-------
@@ -3804,15 +3786,23 @@ def groupby_apply(
dropna = groupby_kwargs.get("dropna", True)
group_keys = groupby_kwargs.get("group_keys", False)
- by_pandas_labels = extract_groupby_column_pandas_labels(self, by, level)
+ by_pandas_labels = (
+ []
+ if force_single_group
+ else extract_groupby_column_pandas_labels(self, by, level)
+ )
- by_snowflake_quoted_identifiers_list = [
- quoted_identifier
- for entry in self._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
- by_pandas_labels
- )
- for quoted_identifier in entry
- ]
+ by_snowflake_quoted_identifiers_list = (
+ []
+ if force_single_group
+ else [
+ quoted_identifier
+ for entry in self._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
+ by_pandas_labels
+ )
+ for quoted_identifier in entry
+ ]
+ )
snowflake_type_map = self._modin_frame.quoted_identifier_to_snowflake_type()
@@ -3846,11 +3836,14 @@ def groupby_apply(
],
session=self._modin_frame.ordered_dataframe.session,
series_groupby=series_groupby,
- by_types=[
+ by_types=[]
+ if force_single_group
+ else [
snowflake_type_map[quoted_identifier]
for quoted_identifier in by_snowflake_quoted_identifiers_list
],
existing_identifiers=self._modin_frame.ordered_dataframe._dataframe_ref.snowflake_quoted_identifiers,
+ force_list_like_to_series=force_list_like_to_series,
)
new_internal_df = self._modin_frame.ensure_row_position_column()
@@ -3922,9 +3915,9 @@ def groupby_apply(
*new_internal_df.index_column_snowflake_quoted_identifiers,
*input_data_column_identifiers,
).over(
- partition_by=[
- *by_snowflake_quoted_identifiers_list,
- ],
+ partition_by=None
+ if force_single_group
+ else [*by_snowflake_quoted_identifiers_list],
order_by=row_position_snowflake_quoted_identifier,
),
)
@@ -4066,7 +4059,9 @@ def groupby_apply(
ordered_dataframe=ordered_dataframe,
agg_func=agg_func,
by_snowflake_quoted_identifiers_list=by_snowflake_quoted_identifiers_list,
- sort_method=groupby_apply_sort_method(
+ sort_method=GroupbyApplySortMethod.ORIGINAL_ROW_ORDER
+ if force_single_group
+ else groupby_apply_sort_method(
sort,
group_keys,
original_row_position_snowflake_quoted_identifier,
@@ -5639,8 +5634,6 @@ def agg(
args: the arguments passed for the aggregation
kwargs: keyword arguments passed for the aggregation function.
"""
- self._raise_not_implemented_error_for_timedelta()
-
numeric_only = kwargs.get("numeric_only", False)
# Call fallback if the aggregation function passed in the arg is currently not supported
# by snowflake engine.
@@ -5686,6 +5679,11 @@ def agg(
not is_list_like(value) for value in func.values()
)
if axis == 1:
+ if any(
+ isinstance(t, TimedeltaType)
+ for t in internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values()
+ ):
+ ErrorMessage.not_implemented_for_timedelta("agg(axis=1)")
if self.is_multiindex():
# TODO SNOW-1010307 fix axis=1 behavior with MultiIndex
ErrorMessage.not_implemented(
@@ -5743,9 +5741,9 @@ def agg(
pandas_column_labels=frame.data_column_pandas_labels,
)
if agg_arg in ("idxmin", "idxmax")
- else generate_rowwise_aggregation_function(agg_arg, kwargs)(
- *(col(c) for c in data_col_identifiers)
- )
+ else get_snowflake_agg_func(
+ agg_arg, kwargs, axis=1
+ ).snowpark_aggregation(*(col(c) for c in data_col_identifiers))
for agg_arg in agg_args
}
pandas_labels = list(agg_col_map.keys())
@@ -5865,7 +5863,13 @@ def generate_agg_qc(
index_column_snowflake_quoted_identifiers=[
agg_name_col_quoted_identifier
],
- data_column_types=None,
+ data_column_types=[
+ col.data_type
+ if isinstance(col.data_type, SnowparkPandasType)
+ and col.snowflake_agg_func.preserves_snowpark_pandas_types
+ else None
+ for col in col_agg_infos
+ ],
index_column_types=None,
)
return SnowflakeQueryCompiler(single_agg_dataframe)
@@ -7377,28 +7381,34 @@ def merge_asof(
SnowflakeQueryCompiler
"""
# TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation
- if (
- by
- or left_by
- or right_by
- or left_index
- or right_index
- or tolerance
- or suffixes != ("_x", "_y")
- ):
+ if left_index or right_index or tolerance or suffixes != ("_x", "_y"):
ErrorMessage.not_implemented(
"Snowpark pandas merge_asof method does not currently support parameters "
- + "'by', 'left_by', 'right_by', 'left_index', 'right_index', "
- + "'suffixes', or 'tolerance'"
+ + "'left_index', 'right_index', 'suffixes', or 'tolerance'"
)
if direction not in ("backward", "forward"):
ErrorMessage.not_implemented(
"Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'"
)
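+        # pandas semantics: direction='backward' selects the last row in `right` whose match key is less than
+        # or equal to the left key (left >= right), while 'forward' selects the first row whose key is greater
+        # than or equal to the left key; allow_exact_matches toggles strict vs. non-strict comparison.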
+ if direction == "backward":
+ match_comparator = (
+ MatchComparator.GREATER_THAN_OR_EQUAL_TO
+ if allow_exact_matches
+ else MatchComparator.GREATER_THAN
+ )
+ else:
+ match_comparator = (
+ MatchComparator.LESS_THAN_OR_EQUAL_TO
+ if allow_exact_matches
+ else MatchComparator.LESS_THAN
+ )
+
left_frame = self._modin_frame
right_frame = right._modin_frame
- left_keys, right_keys = join_utils.get_join_keys(
+        # Get the left and right matching key and quoted identifier corresponding to the match_condition.
+        # There is only one matching key/identifier per table since there is only a single match condition.
+ left_match_keys, right_match_keys = join_utils.get_join_keys(
left=left_frame,
right=right_frame,
on=on,
@@ -7407,42 +7417,62 @@ def merge_asof(
left_index=left_index,
right_index=right_index,
)
- left_match_col = (
+ left_match_identifier = (
left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
- left_keys
+ left_match_keys
)[0][0]
)
- right_match_col = (
+ right_match_identifier = (
right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
- right_keys
+ right_match_keys
)[0][0]
)
-
- if direction == "backward":
- match_comparator = (
- MatchComparator.GREATER_THAN_OR_EQUAL_TO
- if allow_exact_matches
- else MatchComparator.GREATER_THAN
+ coalesce_config = join_utils.get_coalesce_config(
+ left_keys=left_match_keys,
+ right_keys=right_match_keys,
+ external_join_keys=[],
+ )
+
+        # Get the left and right matching keys and quoted identifiers for the 'by' columns, which are passed
+        # as the equality ('on') conditions of the ASOF join.
+ if by or (left_by and right_by):
+ left_on_keys, right_on_keys = join_utils.get_join_keys(
+ left=left_frame,
+ right=right_frame,
+ on=by,
+ left_on=left_by,
+ right_on=right_by,
+ )
+ left_on_identifiers = [
+ ids[0]
+ for ids in left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
+ left_on_keys
+ )
+ ]
+ right_on_identifiers = [
+ ids[0]
+ for ids in right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
+ right_on_keys
+ )
+ ]
+ coalesce_config.extend(
+ join_utils.get_coalesce_config(
+ left_keys=left_on_keys,
+ right_keys=right_on_keys,
+ external_join_keys=[],
+ )
)
else:
- match_comparator = (
- MatchComparator.LESS_THAN_OR_EQUAL_TO
- if allow_exact_matches
- else MatchComparator.LESS_THAN
- )
-
- coalesce_config = join_utils.get_coalesce_config(
- left_keys=left_keys, right_keys=right_keys, external_join_keys=[]
- )
+ left_on_identifiers = []
+ right_on_identifiers = []
joined_frame, _ = join_utils.join(
left=left_frame,
right=right_frame,
+ left_on=left_on_identifiers,
+ right_on=right_on_identifiers,
how="asof",
- left_on=[left_match_col],
- right_on=[right_match_col],
- left_match_col=left_match_col,
- right_match_col=right_match_col,
+ left_match_col=left_match_identifier,
+ right_match_col=right_match_identifier,
match_comparator=match_comparator,
join_key_coalesce_config=coalesce_config,
sort=True,
@@ -7888,11 +7918,6 @@ def apply(
"""
self._raise_not_implemented_error_for_timedelta()
- # axis=0 is not supported, raise error.
- if axis == 0:
- ErrorMessage.not_implemented(
- "Snowpark pandas apply API doesn't yet support axis == 0"
- )
# Only callables are supported for axis=1 mode for now.
if not callable(func) and not isinstance(func, UserDefinedFunction):
ErrorMessage.not_implemented(
@@ -7909,56 +7934,260 @@ def apply(
"Snowpark pandas apply API doesn't yet support DataFrame or Series in 'args' or 'kwargs' of 'func'"
)
- # get input types of all data columns from the dataframe directly
- input_types = self._modin_frame.get_snowflake_type(
- self._modin_frame.data_column_snowflake_quoted_identifiers
- )
+ if axis == 0:
+ frame = self._modin_frame
- from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native
+            # To apply a function to a DataFrame with axis=0, we repurpose the groupby apply machinery: each
+            # column is taken as a Series and treated as a single group to which the function is applied. The
+            # column results are then collected and joined together for the final result.
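+            # Illustrative sketch (hypothetical frame with columns 'a' and 'b'): df.apply(func, axis=0) is
+            # computed roughly as the combination of func(df['a']) and func(df['b']), concatenated along
+            # axis=1 when func returns a Series per column, or stacked into a single Series when func
+            # returns a scalar per column.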
+ col_results = []
- # current columns
- column_index = try_convert_index_to_native(self._modin_frame.data_columns_index)
+ # If raw, then pass numpy ndarray rather than pandas Series as input to the apply function.
+ if raw:
- # Extract return type from annotations (or lookup for known pandas functions) for func object,
- # if not return type could be extracted the variable will hold None.
- return_type = deduce_return_type_from_function(func)
+ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no cover: adding type hint causes an error when creating udtf. also, skip coverage for this function because coverage tools can't tell that we're executing this function because we execute it in a UDTF.
+ raw_input_obj = args[0].to_numpy()
+ args = (raw_input_obj,) + args[1:]
+ return func(*args, **kwargs)
- # Check whether return_type has been extracted. If return type is not
- # a Series, tuple or list object, compute df.apply using a vUDF. In this case no column expansion needs to
- # be performed which means that the result of df.apply(axis=1) is always a Series object.
- if return_type and not (
- isinstance(return_type, PandasSeriesType)
- or isinstance(return_type, ArrayType)
- ):
- return self._apply_udf_row_wise_and_reduce_to_series_along_axis_1(
- func,
- column_index,
- input_types,
- return_type,
- udf_args=args,
- udf_kwargs=kwargs,
- session=self._modin_frame.ordered_dataframe.session,
- )
+ agg_func = wrapped_func
+ else:
+ agg_func = func
+
+ # Accumulate indices of the column results.
+ col_result_indexes = []
+ # Accumulate "is scalar" flags for the column results.
+ col_result_scalars = []
+
+ # Loop through each data column of the original df frame
+ for (column_index, data_column_pair) in enumerate(
+ zip(
+ frame.data_column_pandas_labels,
+ frame.data_column_snowflake_quoted_identifiers,
+ )
+ ):
+ (
+ data_column_pandas_label,
+ data_column_snowflake_quoted_identifier,
+ ) = data_column_pair
+
+                # Create a frame for the current data column, which will be passed to the apply function below.
+                # Note that we maintain the original index because the apply function may access values via the index.
+ data_col_qc = self.take_2d_positional(
+ index=slice(None, None), columns=[column_index]
+ )
+
+ data_col_frame = data_col_qc._modin_frame
+
+ data_col_qc = data_col_qc.groupby_apply(
+ by=[],
+ agg_func=agg_func,
+ axis=0,
+ groupby_kwargs={"as_index": False, "dropna": False},
+ agg_args=args,
+ agg_kwargs=kwargs,
+ series_groupby=True,
+ force_single_group=True,
+ force_list_like_to_series=True,
+ )
+
+ data_col_result_frame = data_col_qc._modin_frame
+
+ # Set the index names and corresponding data column pandas label on the result.
+ data_col_result_frame = InternalFrame.create(
+ ordered_dataframe=data_col_result_frame.ordered_dataframe,
+ data_column_snowflake_quoted_identifiers=data_col_result_frame.data_column_snowflake_quoted_identifiers,
+ data_column_pandas_labels=[data_column_pandas_label],
+ data_column_pandas_index_names=data_col_frame.data_column_pandas_index_names,
+ data_column_types=None,
+ index_column_snowflake_quoted_identifiers=data_col_result_frame.index_column_snowflake_quoted_identifiers,
+ index_column_pandas_labels=data_col_result_frame.index_column_pandas_labels,
+ index_column_types=data_col_result_frame.cached_index_column_snowpark_pandas_types,
+ )
+
+ data_col_result_index = (
+ data_col_result_frame.index_columns_pandas_index()
+ )
+ col_result_indexes.append(data_col_result_index)
+                # TODO: For functions like np.sum, when supported, we will know upfront that the result is a
+                # scalar, so we won't need to look at the index.
+ col_result_scalars.append(
+ len(data_col_result_index) == 1 and data_col_result_index[0] == -1
+ )
+ col_results.append(SnowflakeQueryCompiler(data_col_result_frame))
+
+ result_is_series = False
+
+ if len(col_results) == 1:
+ result_is_series = col_result_scalars[0]
+ qc_result = col_results[0]
+
+ # Squeeze to series if it is single column
+ qc_result = qc_result.columnarize()
+ if col_result_scalars[0]:
+ qc_result = qc_result.reset_index(drop=True)
+ else:
+ single_row_output = all(len(index) == 1 for index in col_result_indexes)
+ if single_row_output:
+ all_scalar_output = all(
+ is_scalar for is_scalar in col_result_scalars
+ )
+ if all_scalar_output:
+ # If the apply function maps all columns to a scalar value, then we need to join them together
+ # to return as a Series result.
+
+ # Ensure all column results have the same column name so concat will be aligned.
+ for i, qc in enumerate(col_results):
+ col_results[i] = qc.set_columns([0])
+
+ qc_result = col_results[0].concat(
+ axis=0,
+ other=col_results[1:],
+ keys=frame.data_column_pandas_labels,
+ )
+ qc_frame = qc_result._modin_frame
+
+ # Drop the extraneous index column from the original result series.
+ qc_result = SnowflakeQueryCompiler(
+ InternalFrame.create(
+ ordered_dataframe=qc_frame.ordered_dataframe,
+ data_column_snowflake_quoted_identifiers=qc_frame.data_column_snowflake_quoted_identifiers,
+ data_column_pandas_labels=qc_frame.data_column_pandas_labels,
+ data_column_pandas_index_names=qc_frame.data_column_pandas_index_names,
+ data_column_types=qc_frame.cached_data_column_snowpark_pandas_types,
+ index_column_snowflake_quoted_identifiers=qc_frame.index_column_snowflake_quoted_identifiers[
+ :-1
+ ],
+ index_column_pandas_labels=qc_frame.index_column_pandas_labels[
+ :-1
+ ],
+ index_column_types=qc_frame.cached_index_column_snowpark_pandas_types[
+ :-1
+ ],
+ )
+ )
+
+ result_is_series = True
+ else:
+ no_scalar_output = all(
+ not is_scalar for is_scalar in col_result_scalars
+ )
+ if no_scalar_output:
+ # Output is Dataframe
+ all_same_index = col_result_indexes.count(
+ col_result_indexes[0]
+ ) == len(col_result_indexes)
+ qc_result = col_results[0].concat(
+ axis=1, other=col_results[1:], sort=not all_same_index
+ )
+ else:
+ # If there's a mix of scalar and pd.Series output from the apply func, pandas stores the
+ # pd.Series output as the value, which we do not currently support.
+ ErrorMessage.not_implemented(
+ "Nested pd.Series in result is not supported in DataFrame.apply(axis=0)"
+ )
+ else:
+ if any(is_scalar for is_scalar in col_result_scalars):
+ # If there's a mix of scalar and pd.Series output from the apply func, pandas stores the
+ # pd.Series output as the value, which we do not currently support.
+ ErrorMessage.not_implemented(
+ "Nested pd.Series in result is not supported in DataFrame.apply(axis=0)"
+ )
+
+ duplicate_index_values = not all(
+ len(i) == len(set(i)) for i in col_result_indexes
+ )
+
+                    # If there are duplicate index values, align on the index to match pandas results.
+ if duplicate_index_values:
+ curr_frame = col_results[0]._modin_frame
+ for next_qc in col_results[1:]:
+ curr_frame = join_utils.align(
+ curr_frame, next_qc._modin_frame, [], [], how="left"
+ ).result_frame
+ qc_result = SnowflakeQueryCompiler(curr_frame)
+ else:
+                        # If there are multiple output Series with different indices, line them up by index in the output.
+ all_same_index = all(
+ all(i == col_result_indexes[0]) for i in col_result_indexes
+ )
+ # If the col results all have same index then we keep the existing index ordering.
+ qc_result = col_results[0].concat(
+ axis=1, other=col_results[1:], sort=not all_same_index
+ )
+
+ # If result should be Series then change the data column label appropriately.
+ if result_is_series:
+ qc_result_frame = qc_result._modin_frame
+ qc_result = SnowflakeQueryCompiler(
+ InternalFrame.create(
+ ordered_dataframe=qc_result_frame.ordered_dataframe,
+ data_column_snowflake_quoted_identifiers=qc_result_frame.data_column_snowflake_quoted_identifiers,
+ data_column_pandas_labels=[MODIN_UNNAMED_SERIES_LABEL],
+ data_column_pandas_index_names=qc_result_frame.data_column_pandas_index_names,
+ data_column_types=qc_result_frame.cached_data_column_snowpark_pandas_types,
+ index_column_snowflake_quoted_identifiers=qc_result_frame.index_column_snowflake_quoted_identifiers,
+ index_column_pandas_labels=qc_result_frame.index_column_pandas_labels,
+ index_column_types=qc_result_frame.cached_index_column_snowpark_pandas_types,
+ )
+ )
+
+ return qc_result
else:
- # Issue actionable warning for users to consider annotating UDF with type annotations
- # for better performance.
- function_name = (
- func.__name__ if isinstance(func, Callable) else str(func) # type: ignore[arg-type]
+ # get input types of all data columns from the dataframe directly
+ input_types = self._modin_frame.get_snowflake_type(
+ self._modin_frame.data_column_snowflake_quoted_identifiers
)
- WarningMessage.single_warning(
- f"Function {function_name} passed to apply does not have type annotations,"
- f" or Snowpark pandas could not extract type annotations. Executing apply"
- f" in slow code path which may result in decreased performance. "
- f"To disable this warning and improve performance, consider annotating"
- f" {function_name} with type annotations."
+
+ from snowflake.snowpark.modin.pandas.utils import (
+ try_convert_index_to_native,
)
- # Result may need to get expanded into multiple columns, or return type of func is not known.
- # Process using UDTF together with dynamic pivot for either case.
- return self._apply_with_udtf_and_dynamic_pivot_along_axis_1(
- func, raw, result_type, args, column_index, input_types, **kwargs
+ # current columns
+ column_index = try_convert_index_to_native(
+ self._modin_frame.data_columns_index
)
+ # Extract return type from annotations (or lookup for known pandas functions) for func object,
+ # if not return type could be extracted the variable will hold None.
+ return_type = deduce_return_type_from_function(func)
+
+ # Check whether return_type has been extracted. If return type is not
+ # a Series, tuple or list object, compute df.apply using a vUDF. In this case no column expansion needs to
+ # be performed which means that the result of df.apply(axis=1) is always a Series object.
+ if return_type and not (
+ isinstance(return_type, PandasSeriesType)
+ or isinstance(return_type, ArrayType)
+ ):
+ return self._apply_udf_row_wise_and_reduce_to_series_along_axis_1(
+ func,
+ column_index,
+ input_types,
+ return_type,
+ udf_args=args,
+ udf_kwargs=kwargs,
+ session=self._modin_frame.ordered_dataframe.session,
+ )
+ else:
+ # Issue actionable warning for users to consider annotating UDF with type annotations
+ # for better performance.
+ function_name = (
+ func.__name__ if isinstance(func, Callable) else str(func) # type: ignore[arg-type]
+ )
+ WarningMessage.single_warning(
+ f"Function {function_name} passed to apply does not have type annotations,"
+ f" or Snowpark pandas could not extract type annotations. Executing apply"
+ f" in slow code path which may result in decreased performance. "
+ f"To disable this warning and improve performance, consider annotating"
+ f" {function_name} with type annotations."
+ )
+
+ # Result may need to get expanded into multiple columns, or return type of func is not known.
+ # Process using UDTF together with dynamic pivot for either case.
+ return self._apply_with_udtf_and_dynamic_pivot_along_axis_1(
+ func, raw, result_type, args, column_index, input_types, **kwargs
+ )
+
def applymap(
self,
func: AggFuncType,
@@ -8912,7 +9141,9 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler":
SnowflakeQueryCompiler
Transposed new QueryCompiler object.
"""
- self._raise_not_implemented_error_for_timedelta()
+ if len(set(self._modin_frame.cached_data_column_snowpark_pandas_types)) > 1:
+ # In this case, transpose may lose types.
+ self._raise_not_implemented_error_for_timedelta()
frame = self._modin_frame
@@ -10548,7 +10779,7 @@ def _make_discrete_difference_expression(
snowpark_pandas_type=None,
)
else:
- return compute_binary_op_between_snowpark_columns(
+ return BinaryOp.create(
"sub",
col(snowflake_quoted_identifier),
lambda: column_datatype,
@@ -10560,7 +10791,7 @@ def _make_discrete_difference_expression(
)
),
lambda: column_datatype,
- )
+ ).compute()
else:
# periods is the number of columns to *go back*.
@@ -10609,13 +10840,13 @@ def _make_discrete_difference_expression(
col1 = cast(col1, IntegerType())
if isinstance(col2_dtype, BooleanType):
col2 = cast(col2, IntegerType())
- return compute_binary_op_between_snowpark_columns(
+ return BinaryOp.create(
"sub",
col1,
lambda: col1_dtype,
col2,
lambda: col2_dtype,
- )
+ ).compute()
def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler":
"""
@@ -12296,8 +12527,6 @@ def _quantiles_single_col(
column would allow us to create an accurate row position column, but would require a
potentially expensive JOIN operator afterwards to apply the correct index labels.
"""
- self._raise_not_implemented_error_for_timedelta()
-
assert len(self._modin_frame.data_column_pandas_labels) == 1
if index is not None:
@@ -12362,7 +12591,7 @@ def _quantiles_single_col(
],
index_column_pandas_labels=[None],
index_column_snowflake_quoted_identifiers=[index_identifier],
- data_column_types=None,
+ data_column_types=original_frame.cached_data_column_snowpark_pandas_types,
index_column_types=None,
)
# We cannot call astype() directly to convert an index column, so we replicate
@@ -13396,6 +13625,16 @@ def _window_agg(
}
).frame
else:
+ snowflake_agg_func = get_snowflake_agg_func(agg_func, agg_kwargs, axis=0)
+ if snowflake_agg_func is None:
+ # We don't have test coverage for this situation because we
+ # test individual rolling and expanding methods we've implemented,
+ # like rolling_sum(), but other rolling methods raise
+ # NotImplementedError immediately. We also don't support rolling
+ # agg(), which might take us here.
+ ErrorMessage.not_implemented( # pragma: no cover
+ f"Window aggregation does not support the aggregation {repr_aggregate_function(agg_func, agg_kwargs)}"
+ )
new_frame = frame.update_snowflake_quoted_identifiers_with_expressions(
{
# If aggregation is count use count on row_position_quoted_identifier
@@ -13406,7 +13645,7 @@ def _window_agg(
if agg_func == "count"
else count(col(quoted_identifier)).over(window_expr)
>= min_periods,
- get_snowflake_agg_func(agg_func, agg_kwargs)(
+ snowflake_agg_func.snowpark_aggregation(
# Expanding is cumulative so replace NULL with 0 for sum aggregation
builtin("zeroifnull")(col(quoted_identifier))
if window_func == WindowFunction.EXPANDING
@@ -14213,7 +14452,7 @@ def _binary_op_between_dataframe_and_series_along_axis_0(
)
)
- # Lazify type map here for calling compute_binary_op_between_snowpark_columns.
+        # Lazify the type map here for calling BinaryOp.compute.
def create_lazy_type_functions(
identifiers: list[str],
) -> list[DataTypeGetter]:
@@ -14243,12 +14482,9 @@ def create_lazy_type_functions(
replace_mapping = {}
snowpark_pandas_types = []
for left, left_datatype in zip(left_result_data_identifiers, left_datatypes):
- (
- expression,
- snowpark_pandas_type,
- ) = compute_binary_op_between_snowpark_columns(
+ (expression, snowpark_pandas_type,) = BinaryOp.create(
op, col(left), left_datatype, col(right), right_datatype
- )
+ ).compute()
snowpark_pandas_types.append(snowpark_pandas_type)
replace_mapping[left] = expression
update_result = joined_frame.result_frame.update_snowflake_quoted_identifiers_with_expressions(
@@ -14363,8 +14599,6 @@ def idxmax(
Returns:
SnowflakeQueryCompiler
"""
- self._raise_not_implemented_error_for_timedelta()
-
return self._idxmax_idxmin(
func="idxmax", axis=axis, skipna=skipna, numeric_only=numeric_only
)
@@ -14389,8 +14623,6 @@ def idxmin(
Returns:
SnowflakeQueryCompiler
"""
- self._raise_not_implemented_error_for_timedelta()
-
return self._idxmax_idxmin(
func="idxmin", axis=axis, skipna=skipna, numeric_only=numeric_only
)
@@ -14507,14 +14739,14 @@ def infer_sorted_column_labels(
replace_mapping = {}
data_column_snowpark_pandas_types = []
for p in left_right_pairs:
- result_expression, snowpark_pandas_type = compute_binary_op_with_fill_value(
+ result_expression, snowpark_pandas_type = BinaryOp.create_with_fill_value(
op=op,
lhs=p.lhs,
lhs_datatype=p.lhs_datatype,
rhs=p.rhs,
rhs_datatype=p.rhs_datatype,
fill_value=fill_value,
- )
+ ).compute()
replace_mapping[p.identifier] = result_expression
data_column_snowpark_pandas_types.append(snowpark_pandas_type)
# Create restricted frame with only combined / replaced labels.
@@ -14781,19 +15013,19 @@ def infer_sorted_column_labels(
snowpark_pandas_labels = []
for label, identifier in overlapping_pairs:
expression, new_type = (
- compute_binary_op_between_scalar_and_snowpark_column(
+ BinaryOp.create_with_lhs_scalar(
op,
series.loc[label],
col(identifier),
datatype_getters[identifier],
- )
+ ).compute()
if squeeze_self
- else compute_binary_op_between_snowpark_column_and_scalar(
+ else BinaryOp.create_with_rhs_scalar(
op,
col(identifier),
datatype_getters[identifier],
series.loc[label],
- )
+ ).compute()
)
snowpark_pandas_labels.append(new_type)
replace_mapping[identifier] = expression
@@ -16454,34 +16686,59 @@ def dt_tz_localize(
tz: Union[str, tzinfo],
ambiguous: str = "raise",
nonexistent: str = "raise",
- ) -> None:
+ include_index: bool = False,
+ ) -> "SnowflakeQueryCompiler":
"""
Localize tz-naive to tz-aware.
Args:
tz : str, pytz.timezone, optional
ambiguous : {"raise", "infer", "NaT"} or bool mask, default: "raise"
nonexistent : {"raise", "shift_forward", "shift_backward", "NaT"} or pandas.timedelta, default: "raise"
+ include_index: Whether to include the index columns in the operation.
Returns:
BaseQueryCompiler
New QueryCompiler containing values with localized time zone.
"""
- ErrorMessage.not_implemented(
- "Snowpark pandas doesn't yet support the method 'Series.dt.tz_localize'"
+ dtype = self.index_dtypes[0] if include_index else self.dtypes[0]
+ if not include_index:
+ method_name = "Series.dt.tz_localize"
+ else:
+ assert is_datetime64_any_dtype(dtype), "column must be datetime"
+ method_name = "DatetimeIndex.tz_localize"
+
+ if not isinstance(ambiguous, str) or ambiguous != "raise":
+ ErrorMessage.parameter_not_implemented_error("ambiguous", method_name)
+ if not isinstance(nonexistent, str) or nonexistent != "raise":
+ ErrorMessage.parameter_not_implemented_error("nonexistent", method_name)
+
+ return SnowflakeQueryCompiler(
+ self._modin_frame.apply_snowpark_function_to_columns(
+ lambda column: tz_localize_column(column, tz),
+ include_index,
+ )
)
- def dt_tz_convert(self, tz: Union[str, tzinfo]) -> None:
+ def dt_tz_convert(
+ self,
+ tz: Union[str, tzinfo],
+ include_index: bool = False,
+ ) -> "SnowflakeQueryCompiler":
"""
Convert time-series data to the specified time zone.
Args:
tz : str, pytz.timezone
+ include_index: Whether to include the index columns in the operation.
Returns:
A new QueryCompiler containing values with converted time zone.
"""
- ErrorMessage.not_implemented(
- "Snowpark pandas doesn't yet support the method 'Series.dt.tz_convert'"
+ return SnowflakeQueryCompiler(
+ self._modin_frame.apply_snowpark_function_to_columns(
+ lambda column: tz_convert_column(column, tz),
+ include_index,
+ )
)
def dt_ceil(
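
A usage-level sketch of the two query-compiler methods implemented above, assuming an active Snowpark session. Per the checks above, only the default `ambiguous="raise"` and `nonexistent="raise"` arguments are accepted.

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

s = pd.Series(pd.to_datetime(["2024-01-01 09:00", "2024-01-02 09:00"]))
localized = s.dt.tz_localize("UTC")                # tz-naive -> tz-aware
converted = localized.dt.tz_convert("US/Eastern")  # shift to another time zone
# Non-default `ambiguous` / `nonexistent` values are rejected with a
# parameter-not-implemented error, as enforced above.
```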
@@ -16524,9 +16781,9 @@ def dt_ceil(
"column must be datetime or timedelta"
) # pragma: no cover
- if ambiguous != "raise":
+ if not isinstance(ambiguous, str) or ambiguous != "raise":
ErrorMessage.parameter_not_implemented_error("ambiguous", method_name)
- if nonexistent != "raise":
+ if not isinstance(nonexistent, str) or nonexistent != "raise":
ErrorMessage.parameter_not_implemented_error("nonexistent", method_name)
if is_datetime64_any_dtype(dtype):
@@ -16604,9 +16861,10 @@ def dt_round(
raise AssertionError(
"column must be datetime or timedelta"
) # pragma: no cover
- if ambiguous != "raise":
+
+ if not isinstance(ambiguous, str) or ambiguous != "raise":
ErrorMessage.parameter_not_implemented_error("ambiguous", method_name)
- if nonexistent != "raise":
+ if not isinstance(nonexistent, str) or nonexistent != "raise":
ErrorMessage.parameter_not_implemented_error("nonexistent", method_name)
if is_datetime64_any_dtype(dtype):
@@ -16762,9 +17020,10 @@ def dt_floor(
raise AssertionError(
"column must be datetime or timedelta"
) # pragma: no cover
- if ambiguous != "raise":
+
+ if not isinstance(ambiguous, str) or ambiguous != "raise":
ErrorMessage.parameter_not_implemented_error("ambiguous", method_name)
- if nonexistent != "raise":
+ if not isinstance(nonexistent, str) or nonexistent != "raise":
ErrorMessage.parameter_not_implemented_error("nonexistent", method_name)
if is_datetime64_any_dtype(dtype):
@@ -17246,9 +17505,11 @@ def equals(
)
replace_mapping = {
- p.identifier: compute_binary_op_between_snowpark_columns(
+ p.identifier: BinaryOp.create(
"equal_null", p.lhs, p.lhs_datatype, p.rhs, p.rhs_datatype
- ).snowpark_column
+ )
+ .compute()
+ .snowpark_column
for p in left_right_pairs
}
@@ -17776,7 +18037,7 @@ def compare(
right_identifier = result_column_mapper.right_quoted_identifiers_map[
right_identifier
]
- op_result = compute_binary_op_between_snowpark_columns(
+ op_result = BinaryOp.create(
op="equal_null",
first_operand=col(left_identifier),
first_datatype=functools.partial(
@@ -17786,7 +18047,7 @@ def compare(
second_datatype=functools.partial(
lambda col: result_frame.get_snowflake_type(col), right_identifier
),
- )
+ ).compute()
binary_op_result = binary_op_result.append_column(
str(left_pandas_label) + "_comparison_result",
op_result.snowpark_column,
@@ -17897,19 +18158,23 @@ def compare(
right_identifier
]
- cols_equal = compute_binary_op_between_snowpark_columns(
- op="equal_null",
- first_operand=col(left_mappped_identifier),
- first_datatype=functools.partial(
- lambda col: result_frame.get_snowflake_type(col),
- left_mappped_identifier,
- ),
- second_operand=col(right_mapped_identifier),
- second_datatype=functools.partial(
- lambda col: result_frame.get_snowflake_type(col),
- right_mapped_identifier,
- ),
- ).snowpark_column
+ cols_equal = (
+ BinaryOp.create(
+ op="equal_null",
+ first_operand=col(left_mappped_identifier),
+ first_datatype=functools.partial(
+ lambda col: result_frame.get_snowflake_type(col),
+ left_mappped_identifier,
+ ),
+ second_operand=col(right_mapped_identifier),
+ second_datatype=functools.partial(
+ lambda col: result_frame.get_snowflake_type(col),
+ right_mapped_identifier,
+ ),
+ )
+ .compute()
+ .snowpark_column
+ )
# Add a column containing the values from `self`, but replace
# matching values with null.
diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/base.py b/src/snowflake/snowpark/modin/plugin/docstrings/base.py
index af50e0379dd..4044f7b675f 100644
--- a/src/snowflake/snowpark/modin/plugin/docstrings/base.py
+++ b/src/snowflake/snowpark/modin/plugin/docstrings/base.py
@@ -2832,6 +2832,7 @@ def shift():
"""
Implement shared functionality between DataFrame and Series for shift. axis argument is only relevant for
Dataframe, and should be 0 for Series.
+
Args:
periods : int | Sequence[int]
Number of periods to shift. Can be positive or negative. If an iterable of ints,
diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py
index 6d79d07ab84..f7e93e6c2df 100644
--- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py
+++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py
@@ -730,7 +730,7 @@ def apply():
Parameters
----------
func : function
- A Python function object to apply to each column or row, or a Python function decorated with @udf.
+ A Python function object to apply to each column or row.
axis : {0 or 'index', 1 or 'columns'}, default 0
Axis along which the function is applied:
@@ -738,8 +738,6 @@ def apply():
* 0 or 'index': apply function to each column.
* 1 or 'columns': apply function to each row.
- Snowpark pandas does not yet support ``axis=0``.
-
raw : bool, default False
Determines if row or column is passed as a Series or ndarray object:
@@ -810,8 +808,6 @@ def apply():
7. When ``func`` uses any first-party modules or third-party packages inside the function,
you need to add these dependencies via ``session.add_import()`` and ``session.add_packages()``.
- Alternatively. specify third-party packages with the @udf decorator. When using the @udf decorator,
- annotations using PandasSeriesType or PandasDataFrameType are not supported.
8. The Snowpark pandas module cannot currently be referenced inside the definition of
``func``. If you need to call a general pandas API like ``pd.Timestamp`` inside ``func``,
@@ -852,22 +848,6 @@ def apply():
1 14.50
2 24.25
dtype: float64
-
- or annotate the function
- with the @udf decorator from Snowpark https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.functions.udf.
-
- >>> from snowflake.snowpark.functions import udf
- >>> from snowflake.snowpark.types import DoubleType
- >>> @udf(packages=['statsmodels>0.12'], return_type=DoubleType())
- ... def autocorr(column):
- ... import pandas as pd
- ... import statsmodels.tsa.stattools
- ... return pd.Series(statsmodels.tsa.stattools.pacf_ols(column.values)).mean()
- ...
- >>> df.apply(autocorr, axis=0) # doctest: +SKIP
- A 0.857143
- B 0.428571
- dtype: float64
"""
def assign():
@@ -1061,8 +1041,6 @@ def transform():
axis : {0 or 'index', 1 or 'columns'}, default 0
If 0 or 'index': apply function to each column. If 1 or 'columns': apply function to each row.
- Snowpark pandas currently only supports axis=1, and does not yet support axis=0.
-
*args
Positional arguments to pass to `func`.
@@ -1771,7 +1749,7 @@ def info():
... 'COL2': ['A', 'B', 'C']})
>>> df.info() # doctest: +NORMALIZE_WHITESPACE
- <class 'snowflake.snowpark.modin.pandas.dataframe.DataFrame'>
+ <class 'modin.pandas.dataframe.DataFrame'>
SnowflakeIndex
Data columns (total 2 columns):
# Column Non-Null Count Dtype
diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series.py b/src/snowflake/snowpark/modin/plugin/docstrings/series.py
index 1d351fd67af..9e4ebd4d257 100644
--- a/src/snowflake/snowpark/modin/plugin/docstrings/series.py
+++ b/src/snowflake/snowpark/modin/plugin/docstrings/series.py
@@ -3428,7 +3428,7 @@ def unique():
>>> pd.Series([pd.Timestamp('2016-01-01', tz='US/Eastern')
... for _ in range(3)]).unique()
- array([Timestamp('2015-12-31 21:00:00-0800', tz='America/Los_Angeles')],
+ array([Timestamp('2016-01-01 00:00:00-0500', tz='UTC-05:00')],
dtype=object)
"""
diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py
index 88c4029a92c..b05d7d76db6 100644
--- a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py
@@ -1858,10 +1858,181 @@ def to_pydatetime():
pass
def tz_localize():
- pass
+ """
+ Localize tz-naive Datetime Array/Index to tz-aware Datetime Array/Index.
+
+ This method takes a time zone (tz) naive Datetime Array/Index object and makes this time zone aware. It does not move the time to another time zone.
+
+ This method can also be used to do the inverse – to create a time zone unaware object from an aware object. To that end, pass tz=None.
+
+ Parameters
+ ----------
+ tz : str, pytz.timezone, dateutil.tz.tzfile, datetime.tzinfo or None
+ Time zone to convert timestamps to. Passing None will remove the time zone information preserving local time.
+ ambiguous : ‘infer’, ‘NaT’, bool array, default ‘raise’
+ When clocks moved backward due to DST, ambiguous times may arise. For example in Central European Time (UTC+01), when going from 03:00 DST to 02:00 non-DST, 02:30:00 local time occurs both at 00:30:00 UTC and at 01:30:00 UTC. In such a situation, the ambiguous parameter dictates how ambiguous times should be handled.
+ - ‘infer’ will attempt to infer fall dst-transition hours based on order
+ - bool-ndarray where True signifies a DST time, False signifies a non-DST time (note that this flag is only applicable for ambiguous times)
+ - ‘NaT’ will return NaT where there are ambiguous times
+ - ‘raise’ will raise an AmbiguousTimeError if there are ambiguous times.
+ nonexistent : ‘shift_forward’, ‘shift_backward’, ‘NaT’, timedelta, default ‘raise’
+ A nonexistent time does not exist in a particular timezone where clocks moved forward due to DST.
+ - ‘shift_forward’ will shift the nonexistent time forward to the closest existing time
+ - ‘shift_backward’ will shift the nonexistent time backward to the closest existing time
+ - ‘NaT’ will return NaT where there are nonexistent times
+ - timedelta objects will shift nonexistent times by the timedelta
+ - ‘raise’ will raise a NonExistentTimeError if there are nonexistent times.
+
+ Returns
+ -------
+ Same type as self
+ Array/Index converted to the specified time zone.
+
+ Raises
+ ------
+ TypeError
+ If the Datetime Array/Index is tz-aware and tz is not None.
+
+ See also
+ --------
+ DatetimeIndex.tz_convert
+ Convert tz-aware DatetimeIndex from one time zone to another.
+
+ Examples
+ --------
+ >>> tz_naive = pd.date_range('2018-03-01 09:00', periods=3)
+ >>> tz_naive
+ DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00',
+ '2018-03-03 09:00:00'],
+ dtype='datetime64[ns]', freq=None)
+
+ Localize DatetimeIndex in US/Eastern time zone:
+
+ >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP
+ >>> tz_aware # doctest: +SKIP
+ DatetimeIndex(['2018-03-01 09:00:00-05:00',
+ '2018-03-02 09:00:00-05:00',
+ '2018-03-03 09:00:00-05:00'],
+ dtype='datetime64[ns, US/Eastern]', freq=None)
+
+ With tz=None, we can remove the time zone information while keeping the local time (not converted to UTC):
+
+ >>> tz_aware.tz_localize(None) # doctest: +SKIP
+ DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00',
+ '2018-03-03 09:00:00'],
+ dtype='datetime64[ns]', freq=None)
+
+ Be careful with DST changes. When there is sequential data, pandas can infer the DST time:
+
+ >>> s = pd.to_datetime(pd.Series(['2018-10-28 01:30:00',
+ ... '2018-10-28 02:00:00',
+ ... '2018-10-28 02:30:00',
+ ... '2018-10-28 02:00:00',
+ ... '2018-10-28 02:30:00',
+ ... '2018-10-28 03:00:00',
+ ... '2018-10-28 03:30:00']))
+ >>> s.dt.tz_localize('CET', ambiguous='infer') # doctest: +SKIP
+ 0 2018-10-28 01:30:00+02:00
+ 1 2018-10-28 02:00:00+02:00
+ 2 2018-10-28 02:30:00+02:00
+ 3 2018-10-28 02:00:00+01:00
+ 4 2018-10-28 02:30:00+01:00
+ 5 2018-10-28 03:00:00+01:00
+ 6 2018-10-28 03:30:00+01:00
+ dtype: datetime64[ns, CET]
+
+ In some cases, inferring the DST is impossible. In such cases, you can pass an ndarray to the ambiguous parameter to set the DST explicitly.
+
+ >>> s = pd.to_datetime(pd.Series(['2018-10-28 01:20:00',
+ ... '2018-10-28 02:36:00',
+ ... '2018-10-28 03:46:00']))
+ >>> s.dt.tz_localize('CET', ambiguous=np.array([True, True, False])) # doctest: +SKIP
+ 0 2018-10-28 01:20:00+02:00
+ 1 2018-10-28 02:36:00+02:00
+ 2 2018-10-28 03:46:00+01:00
+ dtype: datetime64[ns, CET]
+
+ If the DST transition causes nonexistent times, you can shift these dates forward or backward with a timedelta object or ‘shift_forward’ or ‘shift_backward’.
+
+ >>> s = pd.to_datetime(pd.Series(['2015-03-29 02:30:00',
+ ... '2015-03-29 03:30:00']))
+ >>> s.dt.tz_localize('Europe/Warsaw', nonexistent='shift_forward') # doctest: +SKIP
+ 0 2015-03-29 03:00:00+02:00
+ 1 2015-03-29 03:30:00+02:00
+ dtype: datetime64[ns, Europe/Warsaw]
+
+ >>> s.dt.tz_localize('Europe/Warsaw', nonexistent='shift_backward') # doctest: +SKIP
+ 0 2015-03-29 01:59:59.999999999+01:00
+ 1 2015-03-29 03:30:00+02:00
+ dtype: datetime64[ns, Europe/Warsaw]
+
+ >>> s.dt.tz_localize('Europe/Warsaw', nonexistent=pd.Timedelta('1h')) # doctest: +SKIP
+ 0 2015-03-29 03:30:00+02:00
+ 1 2015-03-29 03:30:00+02:00
+ dtype: datetime64[ns, Europe/Warsaw]
+ """
def tz_convert():
- pass
+ """
+ Convert tz-aware Datetime Array/Index from one time zone to another.
+
+ Parameters
+ ----------
+ tz : str, pytz.timezone, dateutil.tz.tzfile, datetime.tzinfo or None
+ Time zone for time. Corresponding timestamps would be converted to this time zone of the Datetime Array/Index. A tz of None will convert to UTC and remove the timezone information.
+
+ Returns
+ -------
+ Array or Index
+
+ Raises
+ ------
+ TypeError
+ If Datetime Array/Index is tz-naive.
+
+ See also
+ --------
+ DatetimeIndex.tz
+ A timezone that has a variable offset from UTC.
+ DatetimeIndex.tz_localize
+ Localize tz-naive DatetimeIndex to a given time zone, or remove timezone from a tz-aware DatetimeIndex.
+
+ Examples
+ --------
+ With the tz parameter, we can change the DatetimeIndex to other time zones:
+
+ >>> dti = pd.date_range(start='2014-08-01 09:00',
+ ... freq='h', periods=3, tz='Europe/Berlin') # doctest: +SKIP
+
+ >>> dti # doctest: +SKIP
+ DatetimeIndex(['2014-08-01 09:00:00+02:00',
+ '2014-08-01 10:00:00+02:00',
+ '2014-08-01 11:00:00+02:00'],
+ dtype='datetime64[ns, Europe/Berlin]', freq='h')
+
+ >>> dti.tz_convert('US/Central') # doctest: +SKIP
+ DatetimeIndex(['2014-08-01 02:00:00-05:00',
+ '2014-08-01 03:00:00-05:00',
+ '2014-08-01 04:00:00-05:00'],
+ dtype='datetime64[ns, US/Central]', freq='h')
+
+ With tz=None, we can remove the timezone (after converting to UTC if necessary):
+
+ >>> dti = pd.date_range(start='2014-08-01 09:00', freq='h',
+ ... periods=3, tz='Europe/Berlin') # doctest: +SKIP
+
+ >>> dti # doctest: +SKIP
+ DatetimeIndex(['2014-08-01 09:00:00+02:00',
+ '2014-08-01 10:00:00+02:00',
+ '2014-08-01 11:00:00+02:00'],
+ dtype='datetime64[ns, Europe/Berlin]', freq='h')
+
+ >>> dti.tz_convert(None) # doctest: +SKIP
+ DatetimeIndex(['2014-08-01 07:00:00',
+ '2014-08-01 08:00:00',
+ '2014-08-01 09:00:00'],
+ dtype='datetime64[ns]', freq='h')
+ """
+ # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests.
def normalize():
pass
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
index aeca9d6e305..ecef6e843ba 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
@@ -60,7 +60,6 @@
validate_percentile,
)
-import snowflake.snowpark.modin.pandas as spd
from snowflake.snowpark.modin.pandas.api.extensions import (
register_dataframe_accessor,
register_series_accessor,
@@ -88,8 +87,6 @@ def register_base_override(method_name: str):
for directly overriding methods on BasePandasDataset, we mock this by performing the override on
DataFrame and Series, and manually performing a `setattr` on the base class. These steps are necessary
to allow both the docstring extension and method dispatch to work properly.
-
- Methods annotated here also are automatically instrumented with Snowpark pandas telemetry.
"""
def decorator(base_method: Any):
@@ -103,10 +100,7 @@ def decorator(base_method: Any):
series_method = series_method.fget
if series_method is None or series_method is parent_method:
register_series_accessor(method_name)(base_method)
- # TODO: SNOW-1063346
- # Since we still use the vendored version of DataFrame and the overrides for the top-level
- # namespace haven't been performed yet, we need to set properties on the vendored version
- df_method = getattr(spd.dataframe.DataFrame, method_name, None)
+ df_method = getattr(pd.DataFrame, method_name, None)
if isinstance(df_method, property):
df_method = df_method.fget
if df_method is None or df_method is parent_method:
@@ -176,6 +170,22 @@ def filter(
pass # pragma: no cover
+@register_base_not_implemented()
+def interpolate(
+ self,
+ method="linear",
+ *,
+ axis=0,
+ limit=None,
+ inplace=False,
+ limit_direction: str | None = None,
+ limit_area=None,
+ downcast=lib.no_default,
+ **kwargs,
+): # noqa: PR01, RT01, D200
+ pass
+
+
@register_base_not_implemented()
def pipe(self, func, *args, **kwargs): # noqa: PR01, RT01, D200
pass # pragma: no cover
@@ -813,7 +823,7 @@ def _binary_op(
**kwargs,
)
- from snowflake.snowpark.modin.pandas.dataframe import DataFrame
+ from modin.pandas.dataframe import DataFrame
# Modin Bug: https://github.com/modin-project/modin/issues/7236
# For a Series interacting with a DataFrame, always return a DataFrame
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py
index 5ce836061ab..62c9cab4dc1 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_overrides.py
@@ -7,20 +7,1443 @@
pandas, such as `DataFrame.memory_usage`.
"""
-from typing import Any, Union
+from __future__ import annotations
+import collections
+import datetime
+import functools
+import itertools
+import sys
+import warnings
+from typing import (
+ IO,
+ Any,
+ Callable,
+ Hashable,
+ Iterable,
+ Iterator,
+ Literal,
+ Mapping,
+ Sequence,
+)
+
+import modin.pandas as pd
+import numpy as np
import pandas as native_pd
-from modin.pandas import DataFrame
-from pandas._typing import Axis, PythonFuncType
-from pandas.core.dtypes.common import is_dict_like, is_list_like
+from modin.pandas import DataFrame, Series
+from modin.pandas.base import BasePandasDataset
+from pandas._libs.lib import NoDefault, no_default
+from pandas._typing import (
+ AggFuncType,
+ AnyArrayLike,
+ Axes,
+ Axis,
+ CompressionOptions,
+ FilePath,
+ FillnaOptions,
+ IgnoreRaise,
+ IndexLabel,
+ Level,
+ PythonFuncType,
+ Renamer,
+ Scalar,
+ StorageOptions,
+ Suffixes,
+ WriteBuffer,
+)
+from pandas.core.common import apply_if_callable, is_bool_indexer
+from pandas.core.dtypes.common import (
+ infer_dtype_from_object,
+ is_bool_dtype,
+ is_dict_like,
+ is_list_like,
+ is_numeric_dtype,
+)
+from pandas.core.dtypes.inference import is_hashable, is_integer
+from pandas.core.indexes.frozen import FrozenList
+from pandas.io.formats.printing import pprint_thing
+from pandas.util._validators import validate_bool_kwarg
+
+from snowflake.snowpark.modin.pandas.api.extensions import register_dataframe_accessor
+from snowflake.snowpark.modin.pandas.groupby import (
+ DataFrameGroupBy,
+ validate_groupby_args,
+)
+from snowflake.snowpark.modin.pandas.snow_partition_iterator import (
+ SnowparkPandasRowPartitionIterator,
+)
+from snowflake.snowpark.modin.pandas.utils import (
+ create_empty_native_pandas_frame,
+ from_non_pandas,
+ from_pandas,
+ is_scalar,
+ raise_if_native_pandas_objects,
+ replace_external_data_keys_with_empty_pandas_series,
+ replace_external_data_keys_with_query_compiler,
+)
+from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
+ is_snowflake_agg_func,
+)
+from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated
+from snowflake.snowpark.modin.plugin._typing import ListLike
+from snowflake.snowpark.modin.plugin.utils.error_message import (
+ ErrorMessage,
+ dataframe_not_implemented,
+)
+from snowflake.snowpark.modin.plugin.utils.frontend_constants import (
+ DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE,
+ DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE,
+ DF_SETITEM_SLICE_AS_SCALAR_VALUE,
+)
+from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
+from snowflake.snowpark.modin.utils import (
+ _inherit_docstrings,
+ hashable,
+ validate_int_kwarg,
+)
+from snowflake.snowpark.udf import UserDefinedFunction
+
+
+def register_dataframe_not_implemented():
+ def decorator(base_method: Any):
+ func = dataframe_not_implemented()(base_method)
+ register_dataframe_accessor(base_method.__name__)(func)
+ return func
+
+ return decorator
+
+
+# === UNIMPLEMENTED METHODS ===
+# The following methods are not implemented in Snowpark pandas, and must be overridden on the
+# frontend. These methods fall into a few categories:
+# 1. Would work in Snowpark pandas, but we have not tested it.
+# 2. Would work in Snowpark pandas, but requires more SQL queries than we are comfortable with.
+# 3. Requires materialization (usually via a frontend _default_to_pandas call).
+# 4. Performs operations on a native pandas Index object that are nontrivial for Snowpark pandas to manage.
+
+
+# Avoid overwriting builtin `map` by accident
+@register_dataframe_accessor("map")
+@dataframe_not_implemented()
+def _map(self, func, na_action: str | None = None, **kwargs) -> DataFrame:
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def boxplot(
+ self,
+ column=None,
+ by=None,
+ ax=None,
+ fontsize=None,
+ rot=0,
+ grid=True,
+ figsize=None,
+ layout=None,
+ return_type=None,
+ backend=None,
+ **kwargs,
+): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def combine(
+ self, other, func, fill_value=None, overwrite=True
+): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def corrwith(
+ self, other, axis=0, drop=False, method="pearson", numeric_only=False
+): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def cov(
+ self, min_periods=None, ddof: int | None = 1, numeric_only=False
+): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def dot(self, other): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def hist(
+ self,
+ column=None,
+ by=None,
+ grid=True,
+ xlabelsize=None,
+ xrot=None,
+ ylabelsize=None,
+ yrot=None,
+ ax=None,
+ sharex=False,
+ sharey=False,
+ figsize=None,
+ layout=None,
+ bins=10,
+ **kwds,
+):
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def isetitem(self, loc, value):
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def prod(
+ self,
+ axis=None,
+ skipna=True,
+ numeric_only=False,
+ min_count=0,
+ **kwargs,
+): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+register_dataframe_accessor("product")(prod)
+
+
+@register_dataframe_not_implemented()
+def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def reindex_like(
+ self,
+ other,
+ method=None,
+ copy: bool | None = None,
+ limit=None,
+ tolerance=None,
+) -> DataFrame: # pragma: no cover
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_gbq(
+ self,
+ destination_table,
+ project_id=None,
+ chunksize=None,
+ reauth=False,
+ if_exists="fail",
+ auth_local_webserver=True,
+ table_schema=None,
+ location=None,
+ progress_bar=True,
+ credentials=None,
+): # pragma: no cover # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None):
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_html(
+ self,
+ buf=None,
+ columns=None,
+ col_space=None,
+ header=True,
+ index=True,
+ na_rep="NaN",
+ formatters=None,
+ float_format=None,
+ sparsify=None,
+ index_names=True,
+ justify=None,
+ max_rows=None,
+ max_cols=None,
+ show_dimensions=False,
+ decimal=".",
+ bold_rows=True,
+ classes=None,
+ escape=True,
+ notebook=False,
+ border=None,
+ table_id=None,
+ render_links=False,
+ encoding=None,
+): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_parquet(
+ self,
+ path=None,
+ engine="auto",
+ compression="snappy",
+ index=None,
+ partition_cols=None,
+ storage_options: StorageOptions = None,
+ **kwargs,
+):
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_period(
+ self, freq=None, axis=0, copy=True
+): # pragma: no cover # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_records(
+ self, index=True, column_dtypes=None, index_dtypes=None
+): # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_stata(
+ self,
+ path: FilePath | WriteBuffer[bytes],
+ convert_dates: dict[Hashable, str] | None = None,
+ write_index: bool = True,
+ byteorder: str | None = None,
+ time_stamp: datetime.datetime | None = None,
+ data_label: str | None = None,
+ variable_labels: dict[Hashable, str] | None = None,
+ version: int | None = 114,
+ convert_strl: Sequence[Hashable] | None = None,
+ compression: CompressionOptions = "infer",
+ storage_options: StorageOptions = None,
+ *,
+ value_labels: dict[Hashable, dict[float | int, str]] | None = None,
+):
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def to_xml(
+ self,
+ path_or_buffer=None,
+ index=True,
+ root_name="data",
+ row_name="row",
+ na_rep=None,
+ attr_cols=None,
+ elem_cols=None,
+ namespaces=None,
+ prefix=None,
+ encoding="utf-8",
+ xml_declaration=True,
+ pretty_print=True,
+ parser="lxml",
+ stylesheet=None,
+ compression="infer",
+ storage_options=None,
+):
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def __delitem__(self, key):
+ pass # pragma: no cover
+
+
+@register_dataframe_accessor("attrs")
+@dataframe_not_implemented()
+@property
+def attrs(self): # noqa: RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_accessor("style")
+@dataframe_not_implemented()
+@property
+def style(self): # noqa: RT01, D200
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def __reduce__(self):
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def __divmod__(self, other):
+ pass # pragma: no cover
+
+
+@register_dataframe_not_implemented()
+def __rdivmod__(self, other):
+ pass # pragma: no cover
+
+
+# The from_dict and from_records accessors are class methods and cannot be overridden via the
+# extensions module, as they need to be foisted onto the namespace directly because they are not
+# routed through getattr. To this end, we manually set DataFrame.from_dict to our new method.
+@dataframe_not_implemented()
+def from_dict(
+ cls, data, orient="columns", dtype=None, columns=None
+): # pragma: no cover # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+DataFrame.from_dict = from_dict
+
+
+@dataframe_not_implemented()
+def from_records(
+ cls,
+ data,
+ index=None,
+ exclude=None,
+ columns=None,
+ coerce_float=False,
+ nrows=None,
+): # pragma: no cover # noqa: PR01, RT01, D200
+ pass # pragma: no cover
+
+
+DataFrame.from_records = from_records
+
+
+# === OVERRIDDEN METHODS ===
+# The below methods have their frontend implementations overridden compared to the version present
+# in series.py. This is usually for one of the following reasons:
+# 1. The underlying QC interface used differs from that of modin. Notably, this applies to aggregate
+# and binary operations; further work is needed to refactor either our implementation or upstream
+# modin's implementation.
+# 2. Modin performs extra validation that issues extra SQL queries. Some of these are already
+# fixed on main; see https://github.com/modin-project/modin/issues/7340 for details.
+# 3. Upstream Modin defaults to pandas for some edge cases. Defaulting to pandas at the query compiler
+# layer is acceptable because we can force the method to raise NotImplementedError, but if a method
+# defaults at the frontend, Modin raises a warning and performs the operation by coercing the
+# dataset to a native pandas object. Removing these is tracked by
+# https://github.com/modin-project/modin/issues/7104
+
+
+# Snowpark pandas overrides the constructor for two reasons:
+# 1. To support the Snowpark pandas lazy index object
+# 2. To avoid raising "UserWarning: Distributing object. This may take some time."
+# when a literal is passed in as data.
+@register_dataframe_accessor("__init__")
+def __init__(
+ self,
+ data=None,
+ index=None,
+ columns=None,
+ dtype=None,
+ copy=None,
+ query_compiler=None,
+) -> None:
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ # Siblings are other dataframes that share the same query compiler. We
+ # use this list to update inplace when there is a shallow copy.
+ from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native
+
+ self._siblings = []
+
+ # Engine.subscribe(_update_engine)
+ if isinstance(data, (DataFrame, Series)):
+ self._query_compiler = data._query_compiler.copy()
+ if index is not None and any(i not in data.index for i in index):
+ ErrorMessage.not_implemented(
+ "Passing non-existent columns or index values to constructor not"
+ + " yet implemented."
+ ) # pragma: no cover
+ if isinstance(data, Series):
+ # We set the column name if it is not in the provided Series
+ if data.name is None:
+ self.columns = [0] if columns is None else columns
+ # If the columns provided are not in the named Series, pandas clears
+ # the DataFrame and sets columns to the columns provided.
+ elif columns is not None and data.name not in columns:
+ self._query_compiler = from_pandas(
+ self.__constructor__(columns=columns)
+ )._query_compiler
+ if index is not None:
+ self._query_compiler = data.loc[index]._query_compiler
+ elif columns is None and index is None:
+ data._add_sibling(self)
+ else:
+ if columns is not None and any(i not in data.columns for i in columns):
+ ErrorMessage.not_implemented(
+ "Passing non-existent columns or index values to constructor not"
+ + " yet implemented."
+ ) # pragma: no cover
+ if index is None:
+ index = slice(None)
+ if columns is None:
+ columns = slice(None)
+ self._query_compiler = data.loc[index, columns]._query_compiler
+
+ # Check type of data and use appropriate constructor
+ elif query_compiler is None:
+ distributed_frame = from_non_pandas(data, index, columns, dtype)
+ if distributed_frame is not None:
+ self._query_compiler = distributed_frame._query_compiler
+ return
+
+ if isinstance(data, native_pd.Index):
+ pass
+ elif is_list_like(data) and not is_dict_like(data):
+ old_dtype = getattr(data, "dtype", None)
+ values = [
+ obj._to_pandas() if isinstance(obj, Series) else obj for obj in data
+ ]
+ if isinstance(data, np.ndarray):
+ data = np.array(values, dtype=old_dtype)
+ else:
+ try:
+ data = type(data)(values, dtype=old_dtype)
+ except TypeError:
+ data = values
+ elif is_dict_like(data) and not isinstance(
+ data, (native_pd.Series, Series, native_pd.DataFrame, DataFrame)
+ ):
+ if columns is not None:
+ data = {key: value for key, value in data.items() if key in columns}
+
+ if len(data) and all(isinstance(v, Series) for v in data.values()):
+ from modin.pandas import concat
+
+ new_qc = concat(data.values(), axis=1, keys=data.keys())._query_compiler
+
+ if dtype is not None:
+ new_qc = new_qc.astype({col: dtype for col in new_qc.columns})
+ if index is not None:
+ new_qc = new_qc.reindex(
+ axis=0, labels=try_convert_index_to_native(index)
+ )
+ if columns is not None:
+ new_qc = new_qc.reindex(
+ axis=1, labels=try_convert_index_to_native(columns)
+ )
+
+ self._query_compiler = new_qc
+ return
+
+ data = {
+ k: v._to_pandas() if isinstance(v, Series) else v
+ for k, v in data.items()
+ }
+ pandas_df = native_pd.DataFrame(
+ data=try_convert_index_to_native(data),
+ index=try_convert_index_to_native(index),
+ columns=try_convert_index_to_native(columns),
+ dtype=dtype,
+ copy=copy,
+ )
+ self._query_compiler = from_pandas(pandas_df)._query_compiler
+ else:
+ self._query_compiler = query_compiler
+
+
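
A small sketch of the two constructor paths called out in the comment above the override (lazy index support, and literal data without the distribution warning); it assumes an active Snowpark session and uses made-up data.

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

idx = pd.Index([10, 20, 30])                    # lazy Snowpark pandas Index
df = pd.DataFrame({"a": [1, 2, 3]}, index=idx)  # accepted without materializing the index
df2 = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"])  # literal data, no "Distributing object" warning
```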
+@register_dataframe_accessor("__dataframe__")
+def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
+ """
+ Get a Modin DataFrame that implements the dataframe exchange protocol.
+
+ See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html.
+
+ Parameters
+ ----------
+ nan_as_null : bool, default: False
+ A keyword intended for the consumer to tell the producer
+ to overwrite null values in the data with ``NaN`` (or ``NaT``).
+ This currently has no effect; once support for nullable extension
+ dtypes is added, this value should be propagated to columns.
+ allow_copy : bool, default: True
+ A keyword that defines whether or not the library is allowed
+ to make a copy of the data. For example, copying data would be necessary
+ if a library supports strided buffers, given that this protocol
+ specifies contiguous buffers. Currently, if the flag is set to ``False``
+ and a copy is needed, a ``RuntimeError`` will be raised.
+
+ Returns
+ -------
+ ProtocolDataframe
+ A dataframe object following the dataframe protocol specification.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ ErrorMessage.not_implemented(
+ "Snowpark pandas does not support the DataFrame interchange "
+ + "protocol method `__dataframe__`. To use Snowpark pandas "
+ + "DataFrames with third-party libraries that try to call the "
+ + "`__dataframe__` method, please convert this Snowpark pandas "
+ + "DataFrame to pandas with `to_pandas()`."
+ )
+
+ return self._query_compiler.to_dataframe(
+ nan_as_null=nan_as_null, allow_copy=allow_copy
+ )
+
+
+# Snowpark pandas defaults to axis=1 instead of axis=0 for these; we need to investigate if the same should
+# apply to upstream Modin.
+@register_dataframe_accessor("__and__")
+def __and__(self, other):
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._binary_op("__and__", other, axis=1)
+
+
+@register_dataframe_accessor("__rand__")
+def __rand__(self, other):
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._binary_op("__rand__", other, axis=1)
+
+
+@register_dataframe_accessor("__or__")
+def __or__(self, other):
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._binary_op("__or__", other, axis=1)
+
+
+@register_dataframe_accessor("__ror__")
+def __ror__(self, other):
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._binary_op("__ror__", other, axis=1)
+
+
+# Upstream Modin defaults to pandas in some cases.
+@register_dataframe_accessor("apply")
+def apply(
+ self,
+ func: AggFuncType | UserDefinedFunction,
+ axis: Axis = 0,
+ raw: bool = False,
+ result_type: Literal["expand", "reduce", "broadcast"] | None = None,
+ args=(),
+ **kwargs,
+):
+ """
+ Apply a function along an axis of the ``DataFrame``.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ axis = self._get_axis_number(axis)
+ query_compiler = self._query_compiler.apply(
+ func,
+ axis,
+ raw=raw,
+ result_type=result_type,
+ args=args,
+ **kwargs,
+ )
+ if not isinstance(query_compiler, type(self._query_compiler)):
+ # A scalar was returned
+ return query_compiler
+
+ # If True, it is an unnamed series.
+ # Theoretically, if df.apply returns a Series, it will only be an unnamed series
+ # because the function is supposed to be series -> scalar.
+ if query_compiler._modin_frame.is_unnamed_series():
+ return Series(query_compiler=query_compiler)
+ else:
+ return self.__constructor__(query_compiler=query_compiler)
+
+
+# Snowpark pandas uses a separate QC method, while modin directly calls map.
+@register_dataframe_accessor("applymap")
+def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs):
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if not callable(func):
+ raise TypeError(f"{func} is not callable")
+ return self.__constructor__(
+ query_compiler=self._query_compiler.applymap(
+ func, na_action=na_action, **kwargs
+ )
+ )
+
+
+# We need to override _get_columns to satisfy
+# tests/unit/modin/test_type_annotations.py::test_properties_snow_1374293[_get_columns-type_hints1]
+# since Modin doesn't provide this type hint.
+def _get_columns(self) -> native_pd.Index:
+ """
+ Get the columns for this Snowpark pandas ``DataFrame``.
+
+ Returns
+ -------
+ Index
+ The columns of this DataFrame.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._query_compiler.columns
+
+
+# Snowpark pandas wraps this in an update_in_place
+def _set_columns(self, new_columns: Axes) -> None:
+ """
+ Set the columns for this Snowpark pandas ``DataFrame``.
+
+ Parameters
+ ----------
+ new_columns :
+ The new columns to set.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ self._update_inplace(
+ new_query_compiler=self._query_compiler.set_columns(new_columns)
+ )
+
+
+register_dataframe_accessor("columns")(property(_get_columns, _set_columns))
+
+
+# Snowpark pandas does preprocessing for numeric_only (should be pushed to QC).
+@register_dataframe_accessor("corr")
+def corr(
+ self,
+ method: str | Callable = "pearson",
+ min_periods: int | None = None,
+ numeric_only: bool = False,
+): # noqa: PR01, RT01, D200
+ """
+ Compute pairwise correlation of columns, excluding NA/null values.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ corr_df = self
+ if numeric_only:
+ corr_df = self.drop(
+ columns=[
+ i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i])
+ ]
+ )
+ return self.__constructor__(
+ query_compiler=corr_df._query_compiler.corr(
+ method=method,
+ min_periods=min_periods,
+ )
+ )
+
+
+# Snowpark pandas does not respect `ignore_index`, and upstream Modin does not respect `how`.
+@register_dataframe_accessor("dropna")
+def dropna(
+ self,
+ *,
+ axis: Axis = 0,
+ how: str | NoDefault = no_default,
+ thresh: int | NoDefault = no_default,
+ subset: IndexLabel = None,
+ inplace: bool = False,
+): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return super(DataFrame, self)._dropna(
+ axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace
+ )
+
+
+# Snowpark pandas uses `self_is_series`, while upstream Modin uses `squeeze_self` and `squeeze_value`.
+@register_dataframe_accessor("fillna")
+def fillna(
+ self,
+ value: Hashable | Mapping | Series | DataFrame = None,
+ *,
+ method: FillnaOptions | None = None,
+ axis: Axis | None = None,
+ inplace: bool = False,
+ limit: int | None = None,
+ downcast: dict | None = None,
+) -> DataFrame | None:
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return super(DataFrame, self).fillna(
+ self_is_series=False,
+ value=value,
+ method=method,
+ axis=axis,
+ inplace=inplace,
+ limit=limit,
+ downcast=downcast,
+ )
+
+
+# Snowpark pandas does different validation and returns a custom GroupBy object.
+@register_dataframe_accessor("groupby")
+def groupby(
+ self,
+ by=None,
+ axis: Axis | NoDefault = no_default,
+ level: IndexLabel | None = None,
+ as_index: bool = True,
+ sort: bool = True,
+ group_keys: bool = True,
+ observed: bool | NoDefault = no_default,
+ dropna: bool = True,
+):
+ """
+ Group ``DataFrame`` using a mapper or by a ``Series`` of columns.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if axis is not no_default:
+ axis = self._get_axis_number(axis)
+ if axis == 1:
+ warnings.warn(
+ "DataFrame.groupby with axis=1 is deprecated. Do "
+ + "`frame.T.groupby(...)` without axis instead.",
+ FutureWarning,
+ stacklevel=1,
+ )
+ else:
+ warnings.warn(
+ "The 'axis' keyword in DataFrame.groupby is deprecated and "
+ + "will be removed in a future version.",
+ FutureWarning,
+ stacklevel=1,
+ )
+ else:
+ axis = 0
+
+ validate_groupby_args(by, level, observed)
+
+ axis = self._get_axis_number(axis)
+
+ if axis != 0 and as_index is False:
+ raise ValueError("as_index=False only valid for axis=0")
+
+ idx_name = None
+
+ if (
+ not isinstance(by, Series)
+ and is_list_like(by)
+ and len(by) == 1
+ # if by is a list-like of (None,), we have to keep it as a list because
+ # None may be referencing a column or index level whose label is
+ # `None`, and by=None would mean that there is no `by` param.
+ and by[0] is not None
+ ):
+ by = by[0]
+
+ if hashable(by) and (
+ not callable(by) and not isinstance(by, (native_pd.Grouper, FrozenList))
+ ):
+ idx_name = by
+ elif isinstance(by, Series):
+ idx_name = by.name
+ if by._parent is self:
+ # if the SnowSeries comes from the current dataframe,
+ # convert it to labels directly for easy processing
+ by = by.name
+ elif is_list_like(by):
+ if axis == 0 and all(
+ (
+ (hashable(o) and (o in self))
+ or isinstance(o, Series)
+ or (is_list_like(o) and len(o) == self.shape[axis])
+ )
+ for o in by
+ ):
+ # Split 'by's into those that belong to self (internal_by)
+ # and those that don't (external_by). For SnowSeries that belong
+ # to the current DataFrame, we convert them to labels for easy processing.
+ internal_by, external_by = [], []
+
+ for current_by in by:
+ if hashable(current_by):
+ internal_by.append(current_by)
+ elif isinstance(current_by, Series):
+ if current_by._parent is self:
+ internal_by.append(current_by.name)
+ else:
+ external_by.append(current_by) # pragma: no cover
+ else:
+ external_by.append(current_by)
+
+ by = internal_by + external_by
+
+ return DataFrameGroupBy(
+ self,
+ by,
+ axis,
+ level,
+ as_index,
+ sort,
+ group_keys,
+ idx_name,
+ observed=observed,
+ dropna=dropna,
+ )
+
+
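
A sketch of the `by` normalization above, assuming an active Snowpark session: a hashable label is kept as an internal grouping key, and a Series pulled from the same frame is converted back to its label.

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame({"A": [1, 1, 2], "B": [10, 20, 30]})
df.groupby("A").sum()      # hashable label -> used directly as the grouping key
df.groupby(df["A"]).sum()  # Series whose parent is this frame -> converted back to the label "A"
```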
+# Upstream Modin uses a proxy DataFrameInfo object
+@register_dataframe_accessor("info")
+def info(
+ self,
+ verbose: bool | None = None,
+ buf: IO[str] | None = None,
+ max_cols: int | None = None,
+ memory_usage: bool | str | None = None,
+ show_counts: bool | None = None,
+ null_counts: bool | None = None,
+): # noqa: PR01, D200
+ """
+ Print a concise summary of the ``DataFrame``.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ def put_str(src, output_len=None, spaces=2):
+ src = str(src)
+ return src.ljust(output_len if output_len else len(src)) + " " * spaces
+
+ def format_size(num):
+ for x in ["bytes", "KB", "MB", "GB", "TB"]:
+ if num < 1024.0:
+ return f"{num:3.1f} {x}"
+ num /= 1024.0
+ return f"{num:3.1f} PB"
+
+ output = []
+
+ type_line = str(type(self))
+ index_line = "SnowflakeIndex"
+ columns = self.columns
+ columns_len = len(columns)
+ dtypes = self.dtypes
+ dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}"
+
+ if max_cols is None:
+ max_cols = 100
+
+ exceeds_info_cols = columns_len > max_cols
+
+ if buf is None:
+ buf = sys.stdout
+
+ if null_counts is None:
+ null_counts = not exceeds_info_cols
+
+ if verbose is None:
+ verbose = not exceeds_info_cols
+
+ if null_counts and verbose:
+ # We're going to take items from `non_null_count` in a loop, which
+ # is slow with `Modin.Series`, so we call `_to_pandas()` here
+ # to make it faster.
+ non_null_count = self.count()._to_pandas()
+
+ if memory_usage is None:
+ memory_usage = True
+
+ def get_header(spaces=2):
+ output = []
+ head_label = " # "
+ column_label = "Column"
+ null_label = "Non-Null Count"
+ dtype_label = "Dtype"
+ non_null_label = " non-null"
+ delimiter = "-"
+
+ lengths = {}
+ lengths["head"] = max(len(head_label), len(pprint_thing(len(columns))))
+ lengths["column"] = max(
+ len(column_label), max(len(pprint_thing(col)) for col in columns)
+ )
+ lengths["dtype"] = len(dtype_label)
+ dtype_spaces = (
+ max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes))
+ - lengths["dtype"]
+ )
+
+ header = put_str(head_label, lengths["head"]) + put_str(
+ column_label, lengths["column"]
+ )
+ if null_counts:
+ lengths["null"] = max(
+ len(null_label),
+ max(len(pprint_thing(x)) for x in non_null_count) + len(non_null_label),
+ )
+ header += put_str(null_label, lengths["null"])
+ header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces)
+
+ output.append(header)
+
+ delimiters = put_str(delimiter * lengths["head"]) + put_str(
+ delimiter * lengths["column"]
+ )
+ if null_counts:
+ delimiters += put_str(delimiter * lengths["null"])
+ delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces)
+ output.append(delimiters)
+
+ return output, lengths
+
+ output.extend([type_line, index_line])
+
+ def verbose_repr(output):
+ columns_line = f"Data columns (total {len(columns)} columns):"
+ header, lengths = get_header()
+ output.extend([columns_line, *header])
+ for i, col in enumerate(columns):
+ i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]])
+
+ to_append = put_str(f" {i}", lengths["head"]) + put_str(
+ col_s, lengths["column"]
+ )
+ if null_counts:
+ non_null = pprint_thing(non_null_count[col])
+ to_append += put_str(f"{non_null} non-null", lengths["null"])
+ to_append += put_str(dtype, lengths["dtype"], spaces=0)
+ output.append(to_append)
+
+ def non_verbose_repr(output):
+ output.append(columns._summary(name="Columns"))
+
+ if verbose:
+ verbose_repr(output)
+ else:
+ non_verbose_repr(output)
+
+ output.append(dtypes_line)
+
+ if memory_usage:
+ deep = memory_usage == "deep"
+ mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum()
+ mem_line = f"memory usage: {format_size(mem_usage_bytes)}"
+
+ output.append(mem_line)
+
+ output.append("")
+ buf.write("\n".join(output))
+
+
+# Snowpark pandas does different validation.
+@register_dataframe_accessor("insert")
+def insert(
+ self,
+ loc: int,
+ column: Hashable,
+ value: Scalar | AnyArrayLike,
+ allow_duplicates: bool | NoDefault = no_default,
+) -> None:
+ """
+ Insert column into ``DataFrame`` at specified location.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ raise_if_native_pandas_objects(value)
+ if allow_duplicates is no_default:
+ allow_duplicates = False
+ if not allow_duplicates and column in self.columns:
+ raise ValueError(f"cannot insert {column}, already exists")
+
+ if not isinstance(loc, int):
+ raise TypeError("loc must be int")
+
+ # If column labels are multilevel, we implement the following behavior (this is
+ # the same as native pandas):
+ # Case 1: if 'column' is a tuple, its length must be the same as the number of
+ # levels; otherwise raise an error.
+ # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in
+ # empty strings to match the number of column levels in this frame.
+ if self.columns.nlevels > 1:
+ if isinstance(column, tuple) and len(column) != self.columns.nlevels:
+ # same error as native pandas.
+ raise ValueError("Item must have length equal to number of levels.")
+ if not isinstance(column, tuple):
+ # Fill empty strings to match length of levels
+ suffix = [""] * (self.columns.nlevels - 1)
+ column = tuple([column] + suffix)
+
+ # Dictionary keys are treated as index column and this should be joined with
+ # index of target dataframe. This behavior is similar to 'value' being DataFrame
+ # or Series, so we simply create Series from dict data here.
+ if isinstance(value, dict):
+ value = Series(value, name=column)
+
+ if isinstance(value, DataFrame) or (
+ isinstance(value, np.ndarray) and len(value.shape) > 1
+ ):
+ # Supported numpy array shapes are
+ # 1. (N, ) -> Ex. [1, 2, 3]
+ # 2. (N, 1) -> Ex> [[1], [2], [3]]
+ if value.shape[1] != 1:
+ if isinstance(value, DataFrame):
+ # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin
+ raise ValueError(
+ f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead."
+ )
+ else:
+ raise ValueError(
+ f"Expected a 1D array, got an array with shape {value.shape}"
+ )
+ # Change numpy array shape from (N, 1) to (N, )
+ if isinstance(value, np.ndarray):
+ value = value.squeeze(axis=1)
+
+ if (
+ is_list_like(value)
+ and not isinstance(value, (Series, DataFrame))
+ and len(value) != self.shape[0]
+ and not 0 == self.shape[0] # dataframe holds no rows
+ ):
+ raise ValueError(
+ "Length of values ({}) does not match length of index ({})".format(
+ len(value), len(self)
+ )
+ )
+ if not -len(self.columns) <= loc <= len(self.columns):
+ raise IndexError(
+ f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}"
+ )
+ elif loc < 0:
+ raise ValueError("unbounded slice")
+
+ join_on_index = False
+ if isinstance(value, (Series, DataFrame)):
+ value = value._query_compiler
+ join_on_index = True
+ elif is_list_like(value):
+ value = Series(value, name=column)._query_compiler
+
+ new_query_compiler = self._query_compiler.insert(loc, column, value, join_on_index)
+ # In pandas, 'insert' operation is always inplace.
+ self._update_inplace(new_query_compiler=new_query_compiler)
+
+
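
A sketch of the MultiIndex-column handling in the `insert` override above, assuming an active Snowpark session: a tuple key must match the number of column levels, and a non-tuple key is padded with empty strings.

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame([[1, 2]], columns=pd.MultiIndex.from_tuples([("x", "a"), ("x", "b")]))
df.insert(0, ("y", "c"), [3])  # tuple length matches the number of levels -> inserted as-is
df.insert(0, "z", [4])         # non-tuple key -> padded to ("z", "")
# df.insert(0, ("y",), [5])    # wrong tuple length -> ValueError, per the check above
```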
+# Snowpark pandas does more specialization based on the type of `values`
+@register_dataframe_accessor("isin")
+def isin(
+ self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike]
+) -> DataFrame:
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if isinstance(values, dict):
+ return super(DataFrame, self).isin(values)
+ elif isinstance(values, Series):
+ # Note: pandas performs explicit is_unique check here, deactivated for performance reasons.
+ # if not values.index.is_unique:
+ # raise ValueError("cannot compute isin with a duplicate axis.")
+ return self.__constructor__(
+ query_compiler=self._query_compiler.isin(values._query_compiler)
+ )
+ elif isinstance(values, DataFrame):
+ # Note: pandas performs explicit is_unique check here, deactivated for performance reasons.
+ # if not (values.columns.is_unique and values.index.is_unique):
+ # raise ValueError("cannot compute isin with a duplicate axis.")
+ return self.__constructor__(
+ query_compiler=self._query_compiler.isin(values._query_compiler)
+ )
+ else:
+ if not is_list_like(values):
+ # throw pandas compatible error
+ raise TypeError(
+ "only list-like or dict-like objects are allowed "
+ f"to be passed to {self.__class__.__name__}.isin(), "
+ f"you passed a '{type(values).__name__}'"
+ )
+ return super(DataFrame, self).isin(values)
+
+
+# Upstream Modin defaults to pandas for some arguments.
+@register_dataframe_accessor("join")
+def join(
+ self,
+ other: DataFrame | Series | Iterable[DataFrame | Series],
+ on: IndexLabel | None = None,
+ how: str = "left",
+ lsuffix: str = "",
+ rsuffix: str = "",
+ sort: bool = False,
+ validate: str | None = None,
+) -> DataFrame:
+ """
+ Join columns of another ``DataFrame``.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ for o in other if isinstance(other, list) else [other]:
+ raise_if_native_pandas_objects(o)
+
+ # Similar to native pandas, we implement 'join' using the 'pd.merge' method.
+ # The following code is copied from native pandas (with a few changes explained below)
+ # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002
+ if isinstance(other, Series):
+ # Same error as native pandas.
+ if other.name is None:
+ raise ValueError("Other Series must have a name")
+ other = DataFrame(other)
+ elif is_list_like(other):
+ if any([isinstance(o, Series) and o.name is None for o in other]):
+ raise ValueError("Other Series must have a name")
+
+ if isinstance(other, DataFrame):
+ if how == "cross":
+ return pd.merge(
+ self,
+ other,
+ how=how,
+ on=on,
+ suffixes=(lsuffix, rsuffix),
+ sort=sort,
+ validate=validate,
+ )
+ return pd.merge(
+ self,
+ other,
+ left_on=on,
+ how=how,
+ left_index=on is None,
+ right_index=True,
+ suffixes=(lsuffix, rsuffix),
+ sort=sort,
+ validate=validate,
+ )
+ else: # List of DataFrame/Series
+ # Same error as native pandas.
+ if on is not None:
+ raise ValueError(
+ "Joining multiple DataFrames only supported for joining on index"
+ )
+
+ # Same error as native pandas.
+ if rsuffix or lsuffix:
+ raise ValueError("Suffixes not supported when joining multiple DataFrames")
+
+ # NOTE: These are not differences between the Snowpark pandas API and pandas behavior;
+ # they are differences in native pandas join behavior depending on whether the
+ # joined frames have a unique index or not.
+
+ # In native pandas, the logic to join multiple DataFrames/Series is data
+ # dependent. Under the hood it will use either the 'concat' or the 'merge' API.
+ # Case 1. If all objects being joined have a unique index, use 'concat' (axis=1)
+ # Case 2. Otherwise use the 'merge' API by looping through objects left to right.
+ # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046
+
+ # Even though concat (axis=1) and merge are very similar APIs they have
+ # some differences which leads to inconsistent behavior in native pandas.
+ # 1. Treatment of un-named Series
+ # Case #1: Un-named series is allowed in concat API. Objects are joined
+ # successfully by assigning a number as columns name (see 'concat' API
+ # documentation for details on treatment of un-named series).
+ # Case #2: It raises 'ValueError: Other Series must have a name'
+
+ # 2. how='right'
+ # Case #1: 'concat' API doesn't support right join. It raises
+ # 'ValueError: Only can inner (intersect) or outer (union) join the other axis'
+ # Case #2: Merges successfully.
+
+ # 3. Joining frames with duplicate labels but no conflict with other frames
+ # Example: self = DataFrame(... columns=["A", "B"])
+ # other = [DataFrame(... columns=["C", "C"])]
+ # Case #1: 'ValueError: Indexes have overlapping values'
+ # Case #2: Merged successfully.
-from snowflake.snowpark.modin.pandas.api.extensions import register_dataframe_accessor
-from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
- is_snowflake_agg_func,
-)
-from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
-from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
-from snowflake.snowpark.modin.utils import _inherit_docstrings, validate_int_kwarg
+ # In addition to this, native pandas implementation also leads to another
+ # type of inconsistency where left.join(other, ...) and
+ # left.join([other], ...) might behave differently for cases mentioned
+ # above.
+ # Example:
+ # import pandas as pd
+ # df = pd.DataFrame({"a": [4, 5]})
+ # other = pd.Series([1, 2])
+ # df.join([other]) # this is successful
+ # df.join(other) # this raises 'ValueError: Other Series must have a name'
+
+    # In the Snowpark pandas API, we provide consistent behavior by always using the
+    # 'merge' API to join multiple DataFrames/Series, so we always follow the behavior
+    # documented as Case #2 above.
+
+ joined = self
+ for frame in other:
+ if isinstance(frame, DataFrame):
+ overlapping_cols = set(joined.columns).intersection(set(frame.columns))
+ if len(overlapping_cols) > 0:
+ # Native pandas raises: 'Indexes have overlapping values'
+ # We differ slightly from native pandas message to make it more
+ # useful to users.
+ raise ValueError(
+ f"Join dataframes have overlapping column labels: {overlapping_cols}"
+ )
+ joined = pd.merge(
+ joined,
+ frame,
+ how=how,
+ left_index=True,
+ right_index=True,
+ validate=validate,
+ sort=sort,
+ suffixes=(None, None),
+ )
+ return joined
+
+
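To make the Case #2 semantics above concrete, here is a minimal sketch, assuming an active Snowpark session and hypothetical frames; it shows that joining a list of frames always takes the merge path, so `how="right"` is accepted and unnamed Series are rejected consistently:

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # assumes an active Snowpark session

left = pd.DataFrame({"a": [4, 5]})
right = pd.DataFrame({"b": [1, 2]})

# List input always goes through pd.merge (Case #2), so how="right" works.
joined = left.join([right], how="right")

# Unnamed Series are rejected up front, matching the merge-based behavior.
try:
    left.join([pd.Series([1, 2])])
except ValueError as err:
    print(err)  # Other Series must have a name
```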
+# Snowpark pandas does extra error checking.
+@register_dataframe_accessor("mask")
+def mask(
+ self,
+ cond: DataFrame | Series | Callable | AnyArrayLike,
+ other: DataFrame | Series | Callable | Scalar | None = np.nan,
+ *,
+ inplace: bool = False,
+ axis: Axis | None = None,
+ level: Level | None = None,
+):
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if isinstance(other, Series) and axis is None:
+ raise ValueError(
+ "df.mask requires an axis parameter (0 or 1) when given a Series"
+ )
+
+ return super(DataFrame, self).mask(
+ cond,
+ other=other,
+ inplace=inplace,
+ axis=axis,
+ level=level,
+ )
+
+
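As a quick illustration of the extra check above, a sketch assuming an active Snowpark session and hypothetical data:

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # assumes an active Snowpark session

df = pd.DataFrame({"a": [1, -2], "b": [-3, 4]})
replacement = pd.Series([0, 0], name="repl")

# Passing a Series as `other` without an axis raises immediately.
try:
    df.mask(df < 0, replacement)
except ValueError as err:
    print(err)  # df.mask requires an axis parameter (0 or 1) when given a Series
```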
+# Snowpark pandas has a fix for a pandas behavior change; the fix is available upstream in Modin 0.30.1 (SNOW-1552497).
+@register_dataframe_accessor("melt")
+def melt(
+ self,
+ id_vars=None,
+ value_vars=None,
+ var_name=None,
+ value_name="value",
+ col_level=None,
+ ignore_index=True,
+): # noqa: PR01, RT01, D200
+ """
+ Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if id_vars is None:
+ id_vars = []
+ if not is_list_like(id_vars):
+ id_vars = [id_vars]
+ if value_vars is None:
+ # Behavior of Index.difference changed in 2.2.x
+ # https://github.com/pandas-dev/pandas/pull/55113
+        # This change needs to be upstreamed to Modin:
+ # https://github.com/modin-project/modin/issues/7206
+ value_vars = self.columns.drop(id_vars)
+ if var_name is None:
+ columns_name = self._query_compiler.get_index_name(axis=1)
+ var_name = columns_name if columns_name is not None else "variable"
+ return self.__constructor__(
+ query_compiler=self._query_compiler.melt(
+ id_vars=id_vars,
+ value_vars=value_vars,
+ var_name=var_name,
+ value_name=value_name,
+ col_level=col_level,
+ ignore_index=ignore_index,
+ )
+ )
+
+
+# Snowpark pandas does more thorough error checking.
+@register_dataframe_accessor("merge")
+def merge(
+ self,
+ right: DataFrame | Series,
+ how: str = "inner",
+ on: IndexLabel | None = None,
+ left_on: Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike] | None = None,
+ right_on: Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike] | None = None,
+ left_index: bool = False,
+ right_index: bool = False,
+ sort: bool = False,
+ suffixes: Suffixes = ("_x", "_y"),
+ copy: bool = True,
+ indicator: bool = False,
+ validate: str | None = None,
+) -> DataFrame:
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ # Raise error if native pandas objects are passed.
+ raise_if_native_pandas_objects(right)
+
+ if isinstance(right, Series) and right.name is None:
+ raise ValueError("Cannot merge a Series without a name")
+ if not isinstance(right, (Series, DataFrame)):
+ raise TypeError(
+ f"Can only merge Series or DataFrame objects, a {type(right)} was passed"
+ )
+
+ if isinstance(right, Series):
+ right_column_nlevels = len(right.name) if isinstance(right.name, tuple) else 1
+ else:
+ right_column_nlevels = right.columns.nlevels
+ if self.columns.nlevels != right_column_nlevels:
+        # This is deprecated in native pandas. We raise an explicit error for this.
+ raise ValueError(
+ "Can not merge objects with different column levels."
+ + f" ({self.columns.nlevels} levels on the left,"
+ + f" {right_column_nlevels} on the right)"
+ )
+
+    # Merge empty native pandas dataframes for error checking. Otherwise, a lot of
+    # validation logic would have to be reimplemented here. This takes care of raising
+    # errors for the following scenarios:
+ # 1. Only 'left_index' is set to True.
+    # 2. Only 'right_index' is set to True.
+ # 3. Only 'left_on' is provided.
+ # 4. Only 'right_on' is provided.
+ # 5. 'on' and 'left_on' both are provided
+ # 6. 'on' and 'right_on' both are provided
+ # 7. 'on' and 'left_index' both are provided
+ # 8. 'on' and 'right_index' both are provided
+ # 9. 'left_on' and 'left_index' both are provided
+ # 10. 'right_on' and 'right_index' both are provided
+ # 11. Length mismatch between 'left_on' and 'right_on'
+ # 12. 'left_index' is not a bool
+ # 13. 'right_index' is not a bool
+ # 14. 'on' is not None and how='cross'
+ # 15. 'left_on' is not None and how='cross'
+ # 16. 'right_on' is not None and how='cross'
+ # 17. 'left_index' is True and how='cross'
+ # 18. 'right_index' is True and how='cross'
+ # 19. Unknown label in 'on', 'left_on' or 'right_on'
+ # 20. Provided 'suffixes' is not sufficient to resolve conflicts.
+ # 21. Merging on column with duplicate labels.
+ # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'}
+ # 23. conflict with existing labels for array-like join key
+ # 24. 'indicator' argument is not bool or str
+ # 25. indicator column label conflicts with existing data labels
+ create_empty_native_pandas_frame(self).merge(
+ create_empty_native_pandas_frame(right),
+ on=on,
+ how=how,
+ left_on=replace_external_data_keys_with_empty_pandas_series(left_on),
+ right_on=replace_external_data_keys_with_empty_pandas_series(right_on),
+ left_index=left_index,
+ right_index=right_index,
+ suffixes=suffixes,
+ indicator=indicator,
+ )
+
+ return self.__constructor__(
+ query_compiler=self._query_compiler.merge(
+ right._query_compiler,
+ how=how,
+ on=on,
+ left_on=replace_external_data_keys_with_query_compiler(self, left_on),
+ right_on=replace_external_data_keys_with_query_compiler(right, right_on),
+ left_index=left_index,
+ right_index=right_index,
+ sort=sort,
+ suffixes=suffixes,
+ copy=copy,
+ indicator=indicator,
+ validate=validate,
+ )
+ )
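The empty-frame trick above delegates pandas' extensive merge validation without touching any data. A standalone sketch of the same idea in plain pandas (the column names are hypothetical):

```python
import pandas as native_pd

def validate_merge_kwargs(left_cols, right_cols, **merge_kwargs):
    """Raise the same errors pandas would, using empty frames only."""
    empty_left = native_pd.DataFrame(columns=left_cols)
    empty_right = native_pd.DataFrame(columns=right_cols)
    # Invalid key/index combinations, unknown labels, and similar
    # misconfigurations raise here without materializing any rows.
    empty_left.merge(empty_right, **merge_kwargs)

validate_merge_kwargs(["id", "a"], ["id", "b"], on="id", how="inner")  # passes silently
try:
    validate_merge_kwargs(["id", "a"], ["id", "b"], on="missing")
except KeyError as err:
    print("validation caught:", err)
```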
@_inherit_docstrings(native_pd.DataFrame.memory_usage, apilink="pandas.DataFrame")
@@ -62,6 +1485,125 @@ def memory_usage(self, index: bool = True, deep: bool = False) -> Any:
return native_pd.Series([0] * len(columns), index=columns)
+# Snowpark pandas handles `inplace` differently.
+@register_dataframe_accessor("replace")
+def replace(
+ self,
+ to_replace=None,
+ value=no_default,
+ inplace: bool = False,
+ limit=None,
+ regex: bool = False,
+ method: str | NoDefault = no_default,
+):
+ """
+ Replace values given in `to_replace` with `value`.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ inplace = validate_bool_kwarg(inplace, "inplace")
+ new_query_compiler = self._query_compiler.replace(
+ to_replace=to_replace,
+ value=value,
+ limit=limit,
+ regex=regex,
+ method=method,
+ )
+ return self._create_or_update_from_compiler(new_query_compiler, inplace)
+
+
+# Snowpark pandas interacts with the inplace flag differently.
+@register_dataframe_accessor("rename")
+def rename(
+ self,
+ mapper: Renamer | None = None,
+ *,
+ index: Renamer | None = None,
+ columns: Renamer | None = None,
+ axis: Axis | None = None,
+ copy: bool | None = None,
+ inplace: bool = False,
+ level: Level | None = None,
+ errors: IgnoreRaise = "ignore",
+) -> DataFrame | None:
+ """
+ Alter axes labels.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ inplace = validate_bool_kwarg(inplace, "inplace")
+ if mapper is None and index is None and columns is None:
+ raise TypeError("must pass an index to rename")
+
+ if index is not None or columns is not None:
+ if axis is not None:
+ raise TypeError(
+ "Cannot specify both 'axis' and any of 'index' or 'columns'"
+ )
+ elif mapper is not None:
+ raise TypeError(
+ "Cannot specify both 'mapper' and any of 'index' or 'columns'"
+ )
+ else:
+ # use the mapper argument
+ if axis and self._get_axis_number(axis) == 1:
+ columns = mapper
+ else:
+ index = mapper
+
+ if copy is not None:
+ WarningMessage.ignored_argument(
+ operation="dataframe.rename",
+ argument="copy",
+ message="copy parameter has been ignored with Snowflake execution engine",
+ )
+
+ if isinstance(index, dict):
+ index = Series(index)
+
+ new_qc = self._query_compiler.rename(
+ index_renamer=index, columns_renamer=columns, level=level, errors=errors
+ )
+ return self._create_or_update_from_compiler(
+ new_query_compiler=new_qc, inplace=inplace
+ )
+
+
+# Upstream modin converts aggfunc to a cython function if it's a string.
+@register_dataframe_accessor("pivot_table")
+def pivot_table(
+ self,
+ values=None,
+ index=None,
+ columns=None,
+ aggfunc="mean",
+ fill_value=None,
+ margins=False,
+ dropna=True,
+ margins_name="All",
+ observed=False,
+ sort=True,
+):
+ """
+ Create a spreadsheet-style pivot table as a ``DataFrame``.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ result = self.__constructor__(
+ query_compiler=self._query_compiler.pivot_table(
+ index=index,
+ values=values,
+ columns=columns,
+ aggfunc=aggfunc,
+ fill_value=fill_value,
+ margins=margins,
+ dropna=dropna,
+ margins_name=margins_name,
+ observed=observed,
+ sort=sort,
+ )
+ )
+ return result
+
+
+# Snowpark pandas produces a different warning for materialization.
@register_dataframe_accessor("plot")
@property
def plot(
@@ -108,11 +1650,227 @@ def plot(
return self._to_pandas().plot
+# Upstream Modin defaults to pandas when `other` is a Series.
+@register_dataframe_accessor("pow")
+def pow(
+ self, other, axis="columns", level=None, fill_value=None
+): # noqa: PR01, RT01, D200
+ """
+ Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`).
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._binary_op(
+ "pow",
+ other,
+ axis=axis,
+ level=level,
+ fill_value=fill_value,
+ )
+
+
+@register_dataframe_accessor("rpow")
+def rpow(
+ self, other, axis="columns", level=None, fill_value=None
+): # noqa: PR01, RT01, D200
+ """
+ Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`).
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._binary_op(
+ "rpow",
+ other,
+ axis=axis,
+ level=level,
+ fill_value=fill_value,
+ )
+
+
+# Snowpark pandas does extra argument validation, and uses iloc instead of drop at the end.
+@register_dataframe_accessor("select_dtypes")
+def select_dtypes(
+ self,
+ include: ListLike | str | type | None = None,
+ exclude: ListLike | str | type | None = None,
+) -> DataFrame:
+ """
+ Return a subset of the ``DataFrame``'s columns based on the column dtypes.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+    # This line defers argument validation to pandas, which will raise errors on our behalf in cases
+    # such as when `include` and `exclude` are both None, the same type is specified in both lists,
+    # or a string dtype (as opposed to object) is specified.
+ native_pd.DataFrame().select_dtypes(include, exclude)
+
+ if include and not is_list_like(include):
+ include = [include]
+ elif include is None:
+ include = []
+ if exclude and not is_list_like(exclude):
+ exclude = [exclude]
+ elif exclude is None:
+ exclude = []
+
+ sel = tuple(map(set, (include, exclude)))
+
+ # The width of the np.int_/float_ alias differs between Windows and other platforms, so
+ # we need to include a workaround.
+ # https://github.com/numpy/numpy/issues/9464
+ # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036
+ def check_sized_number_infer_dtypes(dtype):
+ if (isinstance(dtype, str) and dtype == "int") or (dtype is int):
+ return [np.int32, np.int64]
+ elif dtype == "float" or dtype is float:
+ return [np.float64, np.float32]
+ else:
+ return [infer_dtype_from_object(dtype)]
+
+ include, exclude = map(
+ lambda x: set(
+ itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x))
+ ),
+ sel,
+ )
+ # We need to index on column position rather than label in case of duplicates
+ include_these = native_pd.Series(not bool(include), index=range(len(self.columns)))
+ exclude_these = native_pd.Series(not bool(exclude), index=range(len(self.columns)))
+
+ def is_dtype_instance_mapper(dtype):
+ return functools.partial(issubclass, dtype.type)
+
+ for i, dtype in enumerate(self.dtypes):
+ if include:
+ include_these[i] = any(map(is_dtype_instance_mapper(dtype), include))
+ if exclude:
+ exclude_these[i] = not any(map(is_dtype_instance_mapper(dtype), exclude))
+
+ dtype_indexer = include_these & exclude_these
+ indicate = [i for i, should_keep in dtype_indexer.items() if should_keep]
+ # We need to use iloc instead of drop in case of duplicate column names
+ return self.iloc[:, indicate]
+
+
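For background on the width workaround above, a small sketch in plain pandas/NumPy (illustrative only; the exact native pandas output depends on the platform and NumPy version):

```python
import numpy as np
import pandas as native_pd

# The default integer width is platform/NumPy-version dependent (historically
# int32 on Windows, int64 elsewhere), so a bare "int" request is ambiguous.
print(np.dtype(int))

df = native_pd.DataFrame(
    {
        "i32": np.array([1, 2], dtype=np.int32),
        "i64": np.array([1, 2], dtype=np.int64),
    }
)
# In native pandas, which column matches include="int" depends on that default;
# the override above expands "int" to [np.int32, np.int64] so both columns match.
print(df.select_dtypes(include="int").columns.tolist())
```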
+# Snowpark pandas does extra validation on the `axis` argument.
+@register_dataframe_accessor("set_axis")
+def set_axis(
+ self,
+ labels: IndexLabel,
+ *,
+ axis: Axis = 0,
+ copy: bool | NoDefault = no_default, # ignored
+):
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if not is_scalar(axis):
+ raise TypeError(f"{type(axis).__name__} is not a valid type for axis.")
+ return super(DataFrame, self).set_axis(
+ labels=labels,
+        # 'columns', 'rows', 'index', 0, and 1 are the only valid axis values for df.
+ axis=native_pd.DataFrame._get_axis_name(axis),
+ copy=copy,
+ )
+
+
+# Snowpark pandas needs extra logic for the lazy index class.
+@register_dataframe_accessor("set_index")
+def set_index(
+ self,
+ keys: IndexLabel
+ | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable],
+ drop: bool = True,
+ append: bool = False,
+ inplace: bool = False,
+ verify_integrity: bool = False,
+) -> None | DataFrame:
+ """
+ Set the ``DataFrame`` index using existing columns.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ inplace = validate_bool_kwarg(inplace, "inplace")
+ if not isinstance(keys, list):
+ keys = [keys]
+
+    # Make sure each key is either a hashable label, an Index, or a Series.
+ label_or_series = []
+
+ missing = []
+ columns = self.columns.tolist()
+ for key in keys:
+ raise_if_native_pandas_objects(key)
+ if isinstance(key, pd.Series):
+ label_or_series.append(key._query_compiler)
+ elif isinstance(key, (np.ndarray, list, Iterator)):
+ label_or_series.append(pd.Series(key)._query_compiler)
+ elif isinstance(key, (pd.Index, native_pd.MultiIndex)):
+ label_or_series += [s._query_compiler for s in self._to_series_list(key)]
+ else:
+ if not is_hashable(key):
+ raise TypeError(
+ f'The parameter "keys" may be a column key, one-dimensional array, or a list '
+ f"containing only valid column keys and one-dimensional arrays. Received column "
+ f"of type {type(key)}"
+ )
+ label_or_series.append(key)
+ found = key in columns
+ if columns.count(key) > 1:
+ raise ValueError(f"The column label '{key}' is not unique")
+ elif not found:
+ missing.append(key)
+
+ if missing:
+ raise KeyError(f"None of {missing} are in the columns")
+
+ new_query_compiler = self._query_compiler.set_index(
+ label_or_series, drop=drop, append=append
+ )
+
+ # TODO: SNOW-782633 improve this code once duplicate is supported
+    # this needs to pull the entire index, which is inefficient
+ if verify_integrity and not new_query_compiler.index.is_unique:
+ duplicates = new_query_compiler.index[
+ new_query_compiler.index.to_pandas().duplicated()
+ ].unique()
+ raise ValueError(f"Index has duplicate keys: {duplicates}")
+
+ return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace)
+
+
+# Upstream Modin uses `len(self.index)` instead of `len(self)`, which gives an extra query.
+@register_dataframe_accessor("shape")
+@property
+def shape(self) -> tuple[int, int]:
+ """
+ Return a tuple representing the dimensionality of the ``DataFrame``.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return len(self), len(self.columns)
+
+
+# Snowpark pandas has rewrites to minimize the queries issued by length checks.
+@register_dataframe_accessor("squeeze")
+def squeeze(self, axis: Axis | None = None):
+ """
+ Squeeze 1 dimensional axis objects into scalars.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ axis = self._get_axis_number(axis) if axis is not None else None
+ len_columns = self._query_compiler.get_axis_len(1)
+ if axis == 1 and len_columns == 1:
+ return Series(query_compiler=self._query_compiler)
+ if axis in [0, None]:
+        # get_axis_len(0) results in a SQL query that counts the number of rows in the
+        # current dataframe. We should only compute len_index if axis is 0 or None.
+ len_index = len(self)
+ if axis is None and (len_columns == 1 or len_index == 1):
+ return Series(query_compiler=self._query_compiler).squeeze()
+ if axis == 0 and len_index == 1:
+ return Series(query_compiler=self.T._query_compiler)
+ return self.copy()
+
+
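A short sketch of how the branches above minimize queries (assumes an active Snowpark session; the data is hypothetical). Only the `axis=0`/`axis=None` paths need the row count:

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # assumes an active Snowpark session

single_col = pd.DataFrame({"a": [1, 2, 3]})
s = single_col.squeeze(axis=1)   # Series; only the column count is needed, no row-count query

single_cell = pd.DataFrame({"a": [1]})
x = single_cell.squeeze()        # scalar 1; axis=None also requires the row count
```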
# Upstream modin defines sum differently for series/DF, but we use the same implementation for both.
@register_dataframe_accessor("sum")
def sum(
self,
- axis: Union[Axis, None] = None,
+ axis: Axis | None = None,
skipna: bool = True,
numeric_only: bool = False,
min_count: int = 0,
@@ -130,6 +1888,70 @@ def sum(
)
+# Snowpark pandas raises a warning where modin defaults to pandas.
+@register_dataframe_accessor("stack")
+def stack(
+ self,
+ level: int | str | list = -1,
+ dropna: bool | NoDefault = no_default,
+ sort: bool | NoDefault = no_default,
+ future_stack: bool = False, # ignored
+):
+ """
+ Stack the prescribed level(s) from columns to index.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if future_stack is not False:
+ WarningMessage.ignored_argument( # pragma: no cover
+ operation="DataFrame.stack",
+ argument="future_stack",
+ message="future_stack parameter has been ignored with Snowflake execution engine",
+ )
+ if dropna is NoDefault:
+ dropna = True # pragma: no cover
+ if sort is NoDefault:
+ sort = True # pragma: no cover
+
+ # This ensures that non-pandas MultiIndex objects are caught.
+ is_multiindex = len(self.columns.names) > 1
+ if not is_multiindex or (
+ is_multiindex and is_list_like(level) and len(level) == self.columns.nlevels
+ ):
+ return self._reduce_dimension(
+ query_compiler=self._query_compiler.stack(level, dropna, sort)
+ )
+ else:
+ return self.__constructor__(
+ query_compiler=self._query_compiler.stack(level, dropna, sort)
+ )
+
+
+# Upstream modin doesn't pass `copy`, so we can't raise a warning for it.
+# No need to override the `T` property since that can't take any extra arguments.
+@register_dataframe_accessor("transpose")
+def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200
+ """
+ Transpose index and columns.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if copy:
+ WarningMessage.ignored_argument(
+ operation="transpose",
+ argument="copy",
+ message="Transpose ignore copy argument in Snowpark pandas API",
+ )
+
+ if args:
+ WarningMessage.ignored_argument(
+ operation="transpose",
+ argument="args",
+ message="Transpose ignores args in Snowpark pandas API",
+ )
+
+ return self.__constructor__(query_compiler=self._query_compiler.transpose())
+
+
+# Upstream modin implements transform in base.py, but we don't yet support Series.transform.
@register_dataframe_accessor("transform")
def transform(
self, func: PythonFuncType, axis: Axis = 0, *args: Any, **kwargs: Any
@@ -151,3 +1973,380 @@ def transform(
raise ValueError("Function did not transform")
return self.apply(func, axis, False, args=args, **kwargs)
+
+
+# Upstream modin defaults to pandas for some arguments.
+@register_dataframe_accessor("unstack")
+def unstack(
+ self,
+ level: int | str | list = -1,
+ fill_value: int | str | dict = None,
+ sort: bool = True,
+):
+ """
+ Pivot a level of the (necessarily hierarchical) index labels.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ # This ensures that non-pandas MultiIndex objects are caught.
+ nlevels = self._query_compiler.nlevels()
+ is_multiindex = nlevels > 1
+
+ if not is_multiindex or (
+ is_multiindex and is_list_like(level) and len(level) == nlevels
+ ):
+ return self._reduce_dimension(
+ query_compiler=self._query_compiler.unstack(
+ level, fill_value, sort, is_series_input=False
+ )
+ )
+ else:
+ return self.__constructor__(
+ query_compiler=self._query_compiler.unstack(
+ level, fill_value, sort, is_series_input=False
+ )
+ )
+
+
+# Upstream modin does different validation and sorting.
+@register_dataframe_accessor("value_counts")
+def value_counts(
+ self,
+ subset: Sequence[Hashable] | None = None,
+ normalize: bool = False,
+ sort: bool = True,
+ ascending: bool = False,
+ dropna: bool = True,
+):
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return Series(
+ query_compiler=self._query_compiler.value_counts(
+ subset=subset,
+ normalize=normalize,
+ sort=sort,
+ ascending=ascending,
+ dropna=dropna,
+ ),
+ name="proportion" if normalize else "count",
+ )
+
+
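A brief illustration of the naming behavior above (a sketch assuming an active Snowpark session):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # assumes an active Snowpark session

df = pd.DataFrame({"city": ["NY", "NY", "SF"]})
print(df.value_counts().name)                # count
print(df.value_counts(normalize=True).name)  # proportion
```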
+@register_dataframe_accessor("where")
+def where(
+ self,
+ cond: DataFrame | Series | Callable | AnyArrayLike,
+ other: DataFrame | Series | Callable | Scalar | None = np.nan,
+ *,
+ inplace: bool = False,
+ axis: Axis | None = None,
+ level: Level | None = None,
+):
+ """
+ Replace values where the condition is False.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ if isinstance(other, Series) and axis is None:
+ raise ValueError(
+ "df.where requires an axis parameter (0 or 1) when given a Series"
+ )
+
+ return super(DataFrame, self).where(
+ cond,
+ other=other,
+ inplace=inplace,
+ axis=axis,
+ level=level,
+ )
+
+
+# Snowpark pandas has a custom iterator.
+@register_dataframe_accessor("iterrows")
+def iterrows(self) -> Iterator[tuple[Hashable, Series]]:
+ """
+ Iterate over ``DataFrame`` rows as (index, ``Series``) pairs.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ def iterrow_builder(s):
+ """Return tuple of the given `s` parameter name and the parameter themselves."""
+ return s.name, s
+
+ # Raise warning message since iterrows is very inefficient.
+ WarningMessage.single_warning(
+ DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows")
+ )
+
+ partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder)
+ yield from partition_iterator
+
+
+# Snowpark pandas has a custom iterator.
+@register_dataframe_accessor("itertuples")
+def itertuples(
+ self, index: bool = True, name: str | None = "Pandas"
+) -> Iterable[tuple[Any, ...]]:
+ """
+ Iterate over ``DataFrame`` rows as ``namedtuple``-s.
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+
+ def itertuples_builder(s):
+ """Return the next namedtuple."""
+ # s is the Series of values in the current row.
+ fields = [] # column names
+ data = [] # values under each column
+
+ if index:
+ data.append(s.name)
+ fields.append("Index")
+
+ # Fill column names and values.
+ fields.extend(list(self.columns))
+ data.extend(s)
+
+ if name is not None:
+ # Creating the namedtuple.
+ itertuple = collections.namedtuple(name, fields, rename=True)
+ return itertuple._make(data)
+
+ # When the name is None, return a regular tuple.
+ return tuple(data)
+
+ # Raise warning message since itertuples is very inefficient.
+ WarningMessage.single_warning(
+ DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples")
+ )
+ return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True)
+
+
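Since both iterators above pull rows to the client, the warning message points users toward whole-frame operators. A hedged sketch of the kind of rewrite it suggests (assumes an active Snowpark session):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # assumes an active Snowpark session

df = pd.DataFrame({"a": [1, 2, 3], "b": [10, 20, 30]})

# Row iteration: triggers eager evaluation and pulls data locally.
total = sum(row["a"] * row["b"] for _, row in df.iterrows())

# Preferred: a single whole-frame expression that stays in Snowflake.
total_lazy = (df["a"] * df["b"]).sum()
```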
+# Snowpark pandas truncates the repr output.
+@register_dataframe_accessor("__repr__")
+def __repr__(self):
+ """
+ Return a string representation for a particular ``DataFrame``.
+
+ Returns
+ -------
+ str
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ num_rows = native_pd.get_option("display.max_rows") or len(self)
+    # See _repr_html_ for the comment; also allow the display-all-columns behavior here.
+ num_cols = native_pd.get_option("display.max_columns") or len(self.columns)
+
+ (
+ row_count,
+ col_count,
+ repr_df,
+ ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x")
+ result = repr(repr_df)
+
+ # if truncated, add shape information
+ if is_repr_truncated(row_count, col_count, num_rows, num_cols):
+ # The split here is so that we don't repr pandas row lengths.
+ return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format(
+ row_count, col_count
+ )
+ else:
+ return result
+
+
+# Snowpark pandas uses a different default `num_rows` value.
+@register_dataframe_accessor("_repr_html_")
+def _repr_html_(self): # pragma: no cover
+ """
+ Return a html representation for a particular ``DataFrame``.
+
+ Returns
+ -------
+ str
+
+ Notes
+ -----
+ Supports pandas `display.max_rows` and `display.max_columns` options.
+ """
+ num_rows = native_pd.get_option("display.max_rows") or 60
+    # Modin uses 20 as the default here, but this does not coincide well with the pandas option.
+    # Therefore allow value=0 here, which means display all columns.
+ num_cols = native_pd.get_option("display.max_columns")
+
+ (
+ row_count,
+ col_count,
+ repr_df,
+ ) = self._query_compiler.build_repr_df(num_rows, num_cols)
+ result = repr_df._repr_html_()
+
+ if is_repr_truncated(row_count, col_count, num_rows, num_cols):
+ # We split so that we insert our correct dataframe dimensions.
+        return (
+            result.split("<p>")[0]
+            + f"<p>{row_count} rows × {col_count} columns</p>\n"
+        )
+ else:
+ return result
+
+
+# Upstream modin just uses `to_datetime` rather than `dataframe_to_datetime` on the query compiler.
+@register_dataframe_accessor("_to_datetime")
+def _to_datetime(self, **kwargs):
+ """
+ Convert `self` to datetime.
+
+ Parameters
+ ----------
+ **kwargs : dict
+ Optional arguments to use during query compiler's
+ `to_datetime` invocation.
+
+ Returns
+ -------
+ Series of datetime64 dtype
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._reduce_dimension(
+ query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs)
+ )
+
+
+# Snowpark pandas has the extra `statement_params` argument.
+@register_dataframe_accessor("_to_pandas")
+def _to_pandas(
+ self,
+ *,
+ statement_params: dict[str, str] | None = None,
+ **kwargs: Any,
+) -> native_pd.DataFrame:
+ """
+ Convert Snowpark pandas DataFrame to pandas DataFrame
+
+ Args:
+ statement_params: Dictionary of statement level parameters to be set while executing this action.
+
+ Returns:
+ pandas DataFrame
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ return self._query_compiler.to_pandas(statement_params=statement_params, **kwargs)
+
+
+# Snowpark pandas does more validation and error checking than upstream Modin, and uses different
+# helper methods for dispatch.
+@register_dataframe_accessor("__setitem__")
+def __setitem__(self, key: Any, value: Any):
+ """
+ Set attribute `value` identified by `key`.
+
+ Args:
+ key: Key to set
+ value: Value to set
+
+ Note:
+        In the case where value is any list-like or array, pandas checks the array length against the number of rows
+        of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw
+        a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead, if
+        the array is longer than the number of rows, we ignore the additional values. If the array is shorter, we use
+        enlargement filling with the last value in the array.
+
+ Returns:
+ None
+ """
+ # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
+ key = apply_if_callable(key, self)
+ if isinstance(key, DataFrame) or (
+ isinstance(key, np.ndarray) and len(key.shape) == 2
+ ):
+ # This case uses mask's codepath to perform the set, but
+ # we need to duplicate the code here since we are passing
+ # an additional kwarg `cond_fillna_with_true` to the QC here.
+ # We need this additional kwarg, since if df.shape
+ # and key.shape do not align (i.e. df has more rows),
+ # mask's codepath would mask the additional rows in df
+ # while for setitem, we need to keep the original values.
+ if not isinstance(key, DataFrame):
+ if key.dtype != bool:
+ raise TypeError(
+ "Must pass DataFrame or 2-d ndarray with boolean values only"
+ )
+ key = DataFrame(key)
+ key._query_compiler._shape_hint = "array"
+
+ if value is not None:
+ value = apply_if_callable(value, self)
+
+ if isinstance(value, np.ndarray):
+ value = DataFrame(value)
+ value._query_compiler._shape_hint = "array"
+ elif isinstance(value, pd.Series):
+ # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this
+ # error instead, since it is more descriptive.
+ raise ValueError(
+ "setitem with a 2D key does not support Series values."
+ )
+
+ if isinstance(value, BasePandasDataset):
+ value = value._query_compiler
+
+ query_compiler = self._query_compiler.mask(
+ cond=key._query_compiler,
+ other=value,
+ axis=None,
+ level=None,
+ cond_fillna_with_true=True,
+ )
+
+ return self._create_or_update_from_compiler(query_compiler, inplace=True)
+
+ # Error Checking:
+ if (isinstance(key, pd.Series) or is_list_like(key)) and (isinstance(value, range)):
+ raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE)
+ elif isinstance(value, slice):
+        # Here, the whole slice is assigned as if it were a scalar value, i.e., a single position receives a slice as its value.
+ raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE)
+
+    # Note: when key is a boolean indexer or a slice, the key is a row key; otherwise, the key is always a column
+    # key.
+ index, columns = slice(None), key
+ index_is_bool_indexer = False
+ if isinstance(key, slice):
+ if is_integer(key.start) and is_integer(key.stop):
+            # when the slice is an integer slice, e.g., df[1:2] = val, the behavior is the same as
+ # df.iloc[1:2, :] = val
+ self.iloc[key] = value
+ return
+ index, columns = key, slice(None)
+ elif isinstance(key, pd.Series):
+ if is_bool_dtype(key.dtype):
+ index, columns = key, slice(None)
+ index_is_bool_indexer = True
+ elif is_bool_indexer(key):
+ index, columns = pd.Series(key), slice(None)
+ index_is_bool_indexer = True
+
+    # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case:
+    # we have to explicitly set matching_item_columns_by_label to False for setitem.
+ index = index._query_compiler if isinstance(index, BasePandasDataset) else index
+ columns = (
+ columns._query_compiler if isinstance(columns, BasePandasDataset) else columns
+ )
+ from snowflake.snowpark.modin.pandas.indexing import is_2d_array
+
+ matching_item_rows_by_label = not is_2d_array(value)
+ if is_2d_array(value):
+ value = DataFrame(value)
+ item = value._query_compiler if isinstance(value, BasePandasDataset) else value
+ new_qc = self._query_compiler.set_2d_labels(
+ index,
+ columns,
+ item,
+ # setitem always matches item by position
+ matching_item_columns_by_label=False,
+ matching_item_rows_by_label=matching_item_rows_by_label,
+ index_is_bool_indexer=index_is_bool_indexer,
+ # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling
+ # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the
+ # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have
+ # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns
+ # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B",
+ # "X", "X".
+ deduplicate_columns=True,
+ )
+ return self._update_inplace(new_query_compiler=new_qc)
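A small sketch of the length-mismatch behavior documented in the `__setitem__` note above (assumes an active Snowpark session; the values are hypothetical):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # assumes an active Snowpark session

df = pd.DataFrame({"a": [1, 2, 3]})
df["b"] = [10, 20]        # shorter than the frame: last value is repeated -> [10, 20, 20]
df["c"] = [7, 8, 9, 10]   # longer than the frame: extra values are ignored -> [7, 8, 9]
```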
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
index 7be7adb54c1..38edb9f7bee 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
@@ -229,7 +229,7 @@ def __init__(
--------
>>> idx = pd.DatetimeIndex(["1/1/2020 10:00:00+00:00", "2/1/2020 11:00:00+00:00"], tz="America/Los_Angeles")
>>> idx
- DatetimeIndex(['2020-01-01 02:00:00-08:00', '2020-02-01 03:00:00-08:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
+ DatetimeIndex(['2020-01-01 02:00:00-08:00', '2020-02-01 03:00:00-08:00'], dtype='datetime64[ns, UTC-08:00]', freq=None)
"""
# DatetimeIndex is already initialized in __new__ method. We keep this method
# only for docstring generation.
@@ -960,7 +960,6 @@ def snap(self, freq: Frequency = "S") -> DatetimeIndex:
DatetimeIndex(['2023-01-01', '2023-01-01', '2023-02-01', '2023-02-01'], dtype='datetime64[ns]', freq=None)
"""
- @datetime_index_not_implemented()
def tz_convert(self, tz) -> DatetimeIndex:
"""
Convert tz-aware Datetime Array/Index from one time zone to another.
@@ -1025,8 +1024,14 @@ def tz_convert(self, tz) -> DatetimeIndex:
'2014-08-01 09:00:00'],
dtype='datetime64[ns]', freq='h')
"""
+ # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests.
+ return DatetimeIndex(
+ query_compiler=self._query_compiler.dt_tz_convert(
+ tz,
+ include_index=True,
+ )
+ )
- @datetime_index_not_implemented()
def tz_localize(
self,
tz,
@@ -1104,21 +1109,29 @@ def tz_localize(
Localize DatetimeIndex in US/Eastern time zone:
- >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP
- >>> tz_aware # doctest: +SKIP
- DatetimeIndex(['2018-03-01 09:00:00-05:00',
- '2018-03-02 09:00:00-05:00',
+ >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern')
+ >>> tz_aware
+ DatetimeIndex(['2018-03-01 09:00:00-05:00', '2018-03-02 09:00:00-05:00',
'2018-03-03 09:00:00-05:00'],
- dtype='datetime64[ns, US/Eastern]', freq=None)
+ dtype='datetime64[ns, UTC-05:00]', freq=None)
With the ``tz=None``, we can remove the time zone information
while keeping the local time (not converted to UTC):
- >>> tz_aware.tz_localize(None) # doctest: +SKIP
+ >>> tz_aware.tz_localize(None)
DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00',
'2018-03-03 09:00:00'],
dtype='datetime64[ns]', freq=None)
"""
+ # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests.
+ return DatetimeIndex(
+ query_compiler=self._query_compiler.dt_tz_localize(
+ tz,
+ ambiguous,
+ nonexistent,
+ include_index=True,
+ )
+ )
def round(
self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise"
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py
index 12710224de7..b25bb481dc0 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py
@@ -30,6 +30,7 @@
import modin
import numpy as np
import pandas as native_pd
+from modin.pandas import DataFrame, Series
from modin.pandas.base import BasePandasDataset
from pandas import get_option
from pandas._libs import lib
@@ -49,7 +50,6 @@
)
from pandas.core.dtypes.inference import is_hashable
-from snowflake.snowpark.modin.pandas import DataFrame, Series
from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native
from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta
from snowflake.snowpark.modin.plugin._internal.timestamp_utils import DateTimeOrigin
@@ -71,6 +71,35 @@
}
+class IndexParent:
+ def __init__(self, parent: DataFrame | Series) -> None:
+ """
+ Initialize the IndexParent object.
+
+ IndexParent is used to keep track of the parent object that the Index is a part of.
+ It tracks the parent object and the parent object's query compiler at the time of creation.
+
+ Parameters
+ ----------
+ parent : DataFrame or Series
+ The parent object that the Index is a part of.
+ """
+ assert isinstance(parent, (DataFrame, Series))
+ self._parent = parent
+ self._parent_qc = parent._query_compiler
+
+ def check_and_update_parent_qc_index_names(self, names: list) -> None:
+ """
+        Update the parent's index names only if the query compiler associated with the parent is
+        still the original query compiler recorded, i.e., no inplace update has been applied to the parent.
+ """
+ if self._parent._query_compiler is self._parent_qc:
+ new_query_compiler = self._parent_qc.set_index_names(names)
+ self._parent._update_inplace(new_query_compiler=new_query_compiler)
+ # Update the query compiler after naming operation.
+ self._parent_qc = new_query_compiler
+
+
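A hedged sketch of the intended semantics (assuming an active Snowpark session): renaming an `Index` only propagates to its parent while the parent has not been modified inplace since the `Index` handle was created.

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # assumes an active Snowpark session

s = pd.Series([1, 2], name="col")
idx = s.index                 # IndexParent records s and s._query_compiler
idx.name = "first"            # propagates: s.index.name is now "first"

s.fillna(0, inplace=True)     # inplace update swaps s's query compiler
idx.name = "second"           # no longer propagates; s.index.name stays "first"
```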
class Index(metaclass=TelemetryMeta):
# Equivalent index type in native pandas
@@ -135,7 +164,7 @@ def __new__(
index = object.__new__(cls)
# Initialize the Index
index._query_compiler = query_compiler
- # `_parent` keeps track of any Series or DataFrame that this Index is a part of.
+ # `_parent` keeps track of the parent object that this Index is a part of.
index._parent = None
return index
@@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any:
ErrorMessage.not_implemented(f"Index.{key} is not yet implemented")
raise err
+ def _set_parent(self, parent: Series | DataFrame) -> None:
+ """
+ Set the parent object and its query compiler.
+
+ Parameters
+ ----------
+ parent : Series or DataFrame
+ The parent object that the Index is a part of.
+ """
+ self._parent = IndexParent(parent)
+
def _binary_ops(self, method: str, other: Any) -> Index:
if isinstance(other, Index):
other = other.to_series().reset_index(drop=True)
@@ -408,12 +448,6 @@ def __constructor__(self):
"""
return type(self)
- def _set_parent(self, parent: Series | DataFrame):
- """
- Set the parent object of the current Index to a given Series or DataFrame.
- """
- self._parent = parent
-
@property
def values(self) -> ArrayLike:
"""
@@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None:
if not is_hashable(value):
raise TypeError(f"{type(self).__name__}.name must be a hashable type")
self._query_compiler = self._query_compiler.set_index_names([value])
+        # Update the name of the parent's index only if the parent's current query compiler
+        # matches the originally recorded query compiler, i.e., no inplace update has been
+        # applied to the parent object since this Index was created.
if self._parent is not None:
- self._parent._update_inplace(
- new_query_compiler=self._parent._query_compiler.set_index_names([value])
- )
+ self._parent.check_and_update_parent_qc_index_names([value])
def _get_names(self) -> list[Hashable]:
"""
@@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None:
if isinstance(values, Index):
values = values.to_list()
self._query_compiler = self._query_compiler.set_index_names(values)
+ # Update the name of the parent's index only if the parent's current query compiler
+ # matches the recorded query compiler.
if self._parent is not None:
- self._parent._update_inplace(
- new_query_compiler=self._parent._query_compiler.set_index_names(values)
- )
+ self._parent.check_and_update_parent_qc_index_names(values)
names = property(fset=_set_names, fget=_get_names)
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py
index c5f9e4f6cee..43f9603cfb4 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py
@@ -9,10 +9,10 @@
import inspect
from typing import Any, Iterable, Literal, Optional, Union
+from modin.pandas import DataFrame, Series
from pandas._typing import IndexLabel
from snowflake.snowpark import DataFrame as SnowparkDataFrame
-from snowflake.snowpark.modin.pandas import DataFrame, Series
from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor
from snowflake.snowpark.modin.plugin._internal.telemetry import (
snowpark_pandas_telemetry_standalone_function_decorator,
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py
index dea98bbb0d3..6d6fb4cd0bd 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py
@@ -15,6 +15,7 @@
)
import pandas as native_pd
+from modin.pandas import DataFrame
from pandas._libs.lib import NoDefault, no_default
from pandas._typing import (
CSVEngine,
@@ -26,7 +27,6 @@
)
import snowflake.snowpark.modin.pandas as pd
-from snowflake.snowpark.modin.pandas import DataFrame
from snowflake.snowpark.modin.pandas.api.extensions import register_pd_accessor
from snowflake.snowpark.modin.plugin._internal.telemetry import (
snowpark_pandas_telemetry_standalone_function_decorator,
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py
index f7bba4c743a..5b245bfdab4 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py
@@ -181,7 +181,6 @@ def to_pandas(
See Also:
- :func:`to_pandas `
- - :func:`DataFrame.to_pandas `
Returns:
pandas Series
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
index 5011defa685..b104c223e26 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
@@ -9,22 +9,13 @@
from __future__ import annotations
-from typing import (
- IO,
- TYPE_CHECKING,
- Any,
- Callable,
- Hashable,
- Literal,
- Mapping,
- Sequence,
-)
+from typing import IO, Any, Callable, Hashable, Literal, Mapping, Sequence
import modin.pandas as pd
import numpy as np
import numpy.typing as npt
import pandas as native_pd
-from modin.pandas import Series
+from modin.pandas import DataFrame, Series
from modin.pandas.base import BasePandasDataset
from pandas._libs.lib import NoDefault, is_integer, no_default
from pandas._typing import (
@@ -73,9 +64,6 @@
validate_int_kwarg,
)
-if TYPE_CHECKING:
- from modin.pandas import DataFrame
-
def register_series_not_implemented():
def decorator(base_method: Any):
@@ -209,21 +197,6 @@ def hist(
pass # pragma: no cover
-@register_series_not_implemented()
-def interpolate(
- self,
- method="linear",
- axis=0,
- limit=None,
- inplace=False,
- limit_direction: str | None = None,
- limit_area=None,
- downcast=None,
- **kwargs,
-): # noqa: PR01, RT01, D200
- pass # pragma: no cover
-
-
@register_series_not_implemented()
def item(self): # noqa: RT01, D200
pass # pragma: no cover
@@ -1451,9 +1424,7 @@ def set_axis(
)
-# TODO: SNOW-1063346
-# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once
-# our vendored copy of DataFrame is removed.
+# Snowpark pandas does different validation.
@register_series_accessor("rename")
def rename(
self,
@@ -1503,9 +1474,36 @@ def rename(
return self_cp
-# TODO: SNOW-1063346
-# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once
-# our vendored copy of DataFrame is removed.
+# Modin defaults to pandas for some arguments for unstack
+@register_series_accessor("unstack")
+def unstack(
+ self,
+ level: int | str | list = -1,
+ fill_value: int | str | dict = None,
+ sort: bool = True,
+):
+ """
+ Unstack, also known as pivot, Series with MultiIndex to produce DataFrame.
+ """
+ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
+ from modin.pandas.dataframe import DataFrame
+
+ # We can't unstack a Series object, if we don't have a MultiIndex.
+ if self._query_compiler.has_multiindex:
+ result = DataFrame(
+ query_compiler=self._query_compiler.unstack(
+ level, fill_value, sort, is_series_input=True
+ )
+ )
+ else:
+ raise ValueError( # pragma: no cover
+ f"index must be a MultiIndex to unstack, {type(self.index)} was passed"
+ )
+
+ return result
+
+
+# Snowpark pandas does an extra check on `len(ascending)`.
@register_series_accessor("sort_values")
def sort_values(
self,
@@ -1521,7 +1519,7 @@ def sort_values(
Sort by the values.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
- from snowflake.snowpark.modin.pandas.dataframe import DataFrame
+ from modin.pandas.dataframe import DataFrame
if is_list_like(ascending) and len(ascending) != 1:
raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series")
@@ -1550,38 +1548,6 @@ def sort_values(
return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace)
-# TODO: SNOW-1063346
-# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once
-# our vendored copy of DataFrame is removed.
-# Modin also defaults to pandas for some arguments for unstack
-@register_series_accessor("unstack")
-def unstack(
- self,
- level: int | str | list = -1,
- fill_value: int | str | dict = None,
- sort: bool = True,
-):
- """
- Unstack, also known as pivot, Series with MultiIndex to produce DataFrame.
- """
- # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
- from snowflake.snowpark.modin.pandas.dataframe import DataFrame
-
- # We can't unstack a Series object, if we don't have a MultiIndex.
- if self._query_compiler.has_multiindex:
- result = DataFrame(
- query_compiler=self._query_compiler.unstack(
- level, fill_value, sort, is_series_input=True
- )
- )
- else:
- raise ValueError( # pragma: no cover
- f"index must be a MultiIndex to unstack, {type(self.index)} was passed"
- )
-
- return result
-
-
# Upstream Modin defaults at the frontend layer.
@register_series_accessor("where")
def where(
@@ -1727,63 +1693,6 @@ def to_dict(self, into: type[dict] = dict) -> dict:
return self._to_pandas().to_dict(into=into)
-# TODO: SNOW-1063346
-# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored
-# version of DataFrame, we must keep this override.
-@register_series_accessor("_create_or_update_from_compiler")
-def _create_or_update_from_compiler(self, new_query_compiler, inplace=False):
- """
- Return or update a Series with given `new_query_compiler`.
-
- Parameters
- ----------
- new_query_compiler : PandasQueryCompiler
- QueryCompiler to use to manage the data.
- inplace : bool, default: False
- Whether or not to perform update or creation inplace.
-
- Returns
- -------
- Series, DataFrame or None
- None if update was done, Series or DataFrame otherwise.
- """
- # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
- assert (
- isinstance(new_query_compiler, type(self._query_compiler))
- or type(new_query_compiler) in self._query_compiler.__class__.__bases__
- ), f"Invalid Query Compiler object: {type(new_query_compiler)}"
- if not inplace and new_query_compiler.is_series_like():
- return self.__constructor__(query_compiler=new_query_compiler)
- elif not inplace:
- # This can happen with things like `reset_index` where we can add columns.
- from snowflake.snowpark.modin.pandas.dataframe import DataFrame
-
- return DataFrame(query_compiler=new_query_compiler)
- else:
- self._update_inplace(new_query_compiler=new_query_compiler)
-
-
-# TODO: SNOW-1063346
-# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored
-# version of DataFrame, we must keep this override.
-@register_series_accessor("to_frame")
-def to_frame(self, name: Hashable = no_default) -> DataFrame: # noqa: PR01, RT01, D200
- """
- Convert Series to {label -> value} dict or dict-like object.
- """
- # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
- from snowflake.snowpark.modin.pandas.dataframe import DataFrame
-
- if name is None:
- name = no_default
-
- self_cp = self.copy()
- if name is not no_default:
- self_cp.name = name
-
- return DataFrame(self_cp)
-
-
@register_series_accessor("to_numpy")
def to_numpy(
self,
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py
index 96e2913f556..1cd5e31c63f 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py
@@ -28,11 +28,11 @@
import numpy as np
import pandas as native_pd
+from modin.pandas import DataFrame, Series
from pandas._libs import lib
from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable
from pandas.core.dtypes.common import is_timedelta64_dtype
-from snowflake.snowpark.modin.pandas import DataFrame, Series
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)
@@ -392,12 +392,11 @@ def to_pytimedelta(self) -> np.ndarray:
datetime.timedelta(days=3)], dtype=object)
"""
- @timedelta_index_not_implemented()
def mean(
self, *, skipna: bool = True, axis: AxisInt | None = 0
- ) -> native_pd.Timestamp:
+ ) -> native_pd.Timedelta:
"""
- Return the mean value of the Array.
+ Return the mean value of the Timedelta values.
Parameters
----------
@@ -407,17 +406,46 @@ def mean(
Returns
-------
- scalar Timestamp
+ scalar Timedelta
+
+ Examples
+ --------
+ >>> idx = pd.to_timedelta([1, 2, 3, 1], unit='D')
+ >>> idx
+ TimedeltaIndex(['1 days', '2 days', '3 days', '1 days'], dtype='timedelta64[ns]', freq=None)
+ >>> idx.mean()
+ Timedelta('1 days 18:00:00')
See Also
--------
numpy.ndarray.mean : Returns the average of array elements along a given axis.
Series.mean : Return the mean value in a Series.
-
- Notes
- -----
- mean is only defined for Datetime and Timedelta dtypes, not for Period.
"""
+ if axis:
+ # Native pandas raises IndexError: tuple index out of range
+        # We raise a different, more user-friendly error message.
+ raise ValueError(
+ f"axis should be 0 for TimedeltaIndex.mean, found '{axis}'"
+ )
+ pandas_dataframe_result = (
+ # reset_index(drop=False) copies the index column of
+ # self._query_compiler into a new data column. Use `drop=False`
+ # so that we don't have to use SQL row_number() to generate a new
+ # index column.
+ self._query_compiler.reset_index(drop=False)
+ # Aggregate the data column.
+ .agg("mean", axis=0, args=(), kwargs={"skipna": skipna})
+ # convert the query compiler to a pandas dataframe with
+ # dimensions 1x1 (note that the frame has a single row even
+ # if `self` is empty.)
+ .to_pandas()
+ )
+ assert pandas_dataframe_result.shape == (
+ 1,
+ 1,
+ ), "Internal error: aggregation result is not 1x1."
+ # Return the only element in the frame.
+ return pandas_dataframe_result.iloc[0, 0]
@timedelta_index_not_implemented()
def as_unit(self, unit: str) -> TimedeltaIndex:
diff --git a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py
index 785a492ca89..f3102115a32 100644
--- a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py
+++ b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py
@@ -42,3 +42,17 @@
SERIES_SETITEM_INCOMPATIBLE_INDEXER_WITH_SCALAR_ERROR_MESSAGE = (
"Scalar key incompatible with {} value"
)
+
+DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = (
+ "Currently do not support Series or list-like keys with range-like values"
+)
+
+DF_SETITEM_SLICE_AS_SCALAR_VALUE = (
+ "Currently do not support assigning a slice value as if it's a scalar value"
+)
+
+DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = (
+ "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark "
+ "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which "
+ "can work on the entire DataFrame in one shot."
+)
diff --git a/src/snowflake/snowpark/modin/utils.py b/src/snowflake/snowpark/modin/utils.py
index b1027f00e33..b3446ca0362 100644
--- a/src/snowflake/snowpark/modin/utils.py
+++ b/src/snowflake/snowpark/modin/utils.py
@@ -1171,7 +1171,7 @@ def validate_int_kwarg(value: int, arg_name: str, float_allowed: bool = False) -
def doc_replace_dataframe_with_link(_obj: Any, doc: str) -> str:
"""
Helper function to be passed as the `modify_doc` parameter to `_inherit_docstrings`. This replaces
- all unqualified instances of "DataFrame" with ":class:`~snowflake.snowpark.pandas.DataFrame`" to
+ all unqualified instances of "DataFrame" with ":class:`~modin.pandas.DataFrame`" to
prevent it from linking automatically to snowflake.snowpark.DataFrame: see SNOW-1233342.
To prevent it from overzealously replacing examples in doctests or already-qualified paths, it
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py
index c6faa5c9b3b..a586cb7c000 100644
--- a/src/snowflake/snowpark/session.py
+++ b/src/snowflake/snowpark/session.py
@@ -221,6 +221,16 @@
_PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION = (
"PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION"
)
+_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND = (
+ "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND"
+)
+_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND = (
+ "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND"
+)
+# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
+# in Snowflake. This is the limit where we start seeing compilation errors.
+DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000
+DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND = 12_000_000
WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None
@@ -575,14 +585,22 @@ def __init__(
_PYTHON_SNOWPARK_USE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION, False
)
)
+ # The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
+ # in Snowflake. This is the limit where we start seeing compilation errors.
+ self._large_query_breakdown_complexity_bounds: Tuple[int, int] = (
+ self._conn._get_client_side_session_parameter(
+ _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND,
+ DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND,
+ ),
+ self._conn._get_client_side_session_parameter(
+ _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_UPPER_BOUND,
+ DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND,
+ ),
+ )
self._custom_package_usage_config: Dict = {}
self._conf = self.RuntimeConfig(self, options or {})
- self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None
self._runtime_version_from_requirement: str = None
self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self)
- if self._auto_clean_up_temp_table_enabled:
- self._temp_table_auto_cleaner.start()
-
_logger.info("Snowpark Session information: %s", self._session_info)
def __enter__(self):
@@ -621,8 +639,8 @@ def close(self) -> None:
raise SnowparkClientExceptionMessages.SERVER_FAILED_CLOSE_SESSION(str(ex))
finally:
try:
- self._conn.close()
self._temp_table_auto_cleaner.stop()
+ self._conn.close()
_logger.info("Closed session: %s", self._session_id)
finally:
_remove_session(self)
@@ -656,10 +674,33 @@ def auto_clean_up_temp_table_enabled(self) -> bool:
:meth:`DataFrame.cache_result` in the current session when the DataFrame is no longer referenced (i.e., gets garbage collected).
The default value is ``False``.
+ Example::
+
+ >>> import gc
+ >>>
+ >>> def f(session: Session) -> str:
+ ... df = session.create_dataframe(
+ ... [[1, 2], [3, 4]], schema=["a", "b"]
+ ... ).cache_result()
+ ... return df.table_name
+ ...
+ >>> session.auto_clean_up_temp_table_enabled = True
+ >>> table_name = f(session)
+ >>> assert table_name
+ >>> gc.collect() # doctest: +SKIP
+ >>>
+ >>> # The temporary table created by cache_result will be dropped when the DataFrame is no longer referenced
+ >>> # outside the function
+ >>> session.sql(f"show tables like '{table_name}'").count()
+ 0
+
+ >>> session.auto_clean_up_temp_table_enabled = False
+
Note:
- Even if this parameter is ``False``, Snowpark still records temporary tables when
- their corresponding DataFrame are garbage collected. Therefore, if you turn it on in the middle of your session or after turning it off,
- the target temporary tables will still be cleaned up accordingly.
+ Temporary tables will only be dropped if this parameter is enabled during garbage collection.
+ If a temporary table is no longer referenced when the parameter is on, it will be dropped during garbage collection.
+ However, if garbage collection occurs while the parameter is off, the table will not be removed.
+ Note that Python's garbage collection is triggered opportunistically, with no guaranteed timing.
"""
return self._auto_clean_up_temp_table_enabled
@@ -667,6 +708,10 @@ def auto_clean_up_temp_table_enabled(self) -> bool:
def large_query_breakdown_enabled(self) -> bool:
return self._large_query_breakdown_enabled
+ @property
+ def large_query_breakdown_complexity_bounds(self) -> Tuple[int, int]:
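+        """Return the current ``(lower, upper)`` complexity score bounds used by the large query breakdown optimization."""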
+ return self._large_query_breakdown_complexity_bounds
+
@property
def custom_package_usage_config(self) -> Dict:
"""Get or set configuration parameters related to usage of custom Python packages in Snowflake.
@@ -753,11 +798,6 @@ def auto_clean_up_temp_table_enabled(self, value: bool) -> None:
self._session_id, value
)
self._auto_clean_up_temp_table_enabled = value
- is_alive = self._temp_table_auto_cleaner.is_alive()
- if value and not is_alive:
- self._temp_table_auto_cleaner.start()
- elif not value and is_alive:
- self._temp_table_auto_cleaner.stop()
else:
raise ValueError(
"value for auto_clean_up_temp_table_enabled must be True or False!"
@@ -782,6 +822,24 @@ def large_query_breakdown_enabled(self, value: bool) -> None:
"value for large_query_breakdown_enabled must be True or False!"
)
+ @large_query_breakdown_complexity_bounds.setter
+ def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None:
+ """Set the lower and upper bounds for the complexity score used in large query breakdown optimization."""
+
+ if len(value) != 2:
+ raise ValueError(
+ f"Expecting a tuple of two integers. Got a tuple of length {len(value)}"
+ )
+ if value[0] >= value[1]:
+ raise ValueError(
+                f"Expecting a tuple of lower and upper bound with the lower bound less than the upper bound. Got (lower, upper) = ({value[0]}, {value[1]})"
+ )
+ self._conn._telemetry_client.send_large_query_breakdown_update_complexity_bounds(
+ self._session_id, value[0], value[1]
+ )
+
+ self._large_query_breakdown_complexity_bounds = value
+
@custom_package_usage_config.setter
@experimental_parameter(version="1.6.0")
def custom_package_usage_config(self, config: Dict) -> None:
@@ -1649,8 +1707,8 @@ def _upload_unsupported_packages(
try:
# Setup a temporary directory and target folder where pip install will take place.
- self._tmpdir_handler = tempfile.TemporaryDirectory()
- tmpdir = self._tmpdir_handler.name
+ tmpdir_handler = tempfile.TemporaryDirectory()
+ tmpdir = tmpdir_handler.name
target = os.path.join(tmpdir, "unsupported_packages")
if not os.path.exists(target):
os.makedirs(target)
@@ -1735,9 +1793,7 @@ def _upload_unsupported_packages(
for requirement in supported_dependencies + new_dependencies
]
)
- metadata_local_path = os.path.join(
- self._tmpdir_handler.name, metadata_file
- )
+ metadata_local_path = os.path.join(tmpdir_handler.name, metadata_file)
with open(metadata_local_path, "w") as file:
for key, value in metadata.items():
file.write(f"{key},{value}\n")
@@ -1773,9 +1829,8 @@ def _upload_unsupported_packages(
f"-third-party-packages-from-anaconda-in-a-udf."
)
finally:
- if self._tmpdir_handler:
- self._tmpdir_handler.cleanup()
- self._tmpdir_handler = None
+ if tmpdir_handler:
+ tmpdir_handler.cleanup()
return supported_dependencies + new_dependencies
@@ -3094,7 +3149,9 @@ def _use_object(self, object_name: str, object_type: str) -> None:
# we do not validate here
object_type = match.group(1)
object_name = match.group(2)
- setattr(self._conn, f"_active_{object_type}", object_name)
+ mock_conn_lock = self._conn.get_lock()
+ with mock_conn_lock:
+ setattr(self._conn, f"_active_{object_type}", object_name)
else:
self._run_query(query)
else:
diff --git a/src/snowflake/snowpark/version.py b/src/snowflake/snowpark/version.py
index 3955dbbbf33..798a3d902d0 100644
--- a/src/snowflake/snowpark/version.py
+++ b/src/snowflake/snowpark/version.py
@@ -4,4 +4,4 @@
#
# Update this for the versions
-VERSION = (1, 21, 1)
+VERSION = (1, 22, 1)
diff --git a/tests/integ/modin/conftest.py b/tests/integ/modin/conftest.py
index 2f24954e769..a7217b38a50 100644
--- a/tests/integ/modin/conftest.py
+++ b/tests/integ/modin/conftest.py
@@ -715,3 +715,30 @@ def numeric_test_data_4x4():
"C": [7, 10, 13, 16],
"D": [8, 11, 14, 17],
}
+
+
+@pytest.fixture
+def timedelta_native_df() -> pandas.DataFrame:
+ return pandas.DataFrame(
+ {
+ "A": [
+ pd.Timedelta(days=1),
+ pd.Timedelta(days=2),
+ pd.Timedelta(days=3),
+ pd.Timedelta(days=4),
+ ],
+ "B": [
+ pd.Timedelta(minutes=-1),
+ pd.Timedelta(minutes=0),
+ pd.Timedelta(minutes=5),
+ pd.Timedelta(minutes=6),
+ ],
+ "C": [
+ None,
+ pd.Timedelta(nanoseconds=5),
+ pd.Timedelta(nanoseconds=0),
+ pd.Timedelta(nanoseconds=4),
+ ],
+ "D": pandas.to_timedelta([pd.NaT] * 4),
+ }
+ )
diff --git a/tests/integ/modin/frame/test_aggregate.py b/tests/integ/modin/frame/test_aggregate.py
index b018682b6f8..ba68ae13734 100644
--- a/tests/integ/modin/frame/test_aggregate.py
+++ b/tests/integ/modin/frame/test_aggregate.py
@@ -187,6 +187,108 @@ def test_string_sum_with_nulls():
assert_series_equal(snow_result.to_pandas(), native_pd.Series(["ab"]))
+class TestTimedelta:
+ """Test aggregating dataframes containing timedelta columns."""
+
+ @pytest.mark.parametrize(
+ "func, union_count",
+ [
+ param(
+ lambda df: df.aggregate(["min"]),
+ 0,
+ id="aggregate_list_with_one_element",
+ ),
+ param(lambda df: df.aggregate(x=("A", "max")), 0, id="single_named_agg"),
+ # this works since all results are timedelta and we don't need to do any concats.
+ param(
+ lambda df: df.aggregate({"B": "mean", "A": "sum"}),
+ 0,
+ id="dict_producing_two_timedeltas",
+ ),
+        # this works because even though we need to do concats, all the results are non-timedelta.
+ param(
+ lambda df: df.aggregate(x=("B", "all"), y=("B", "any")),
+ 1,
+ id="named_agg_producing_two_bools",
+ ),
+        # note: the following aggregation requires a transpose
+ param(lambda df: df.aggregate(max), 0, id="aggregate_max"),
+ param(lambda df: df.min(), 0, id="min"),
+ param(lambda df: df.max(), 0, id="max"),
+ param(lambda df: df.count(), 0, id="count"),
+ param(lambda df: df.sum(), 0, id="sum"),
+ param(lambda df: df.mean(), 0, id="mean"),
+ param(lambda df: df.median(), 0, id="median"),
+ param(lambda df: df.std(), 0, id="std"),
+ param(lambda df: df.quantile(), 0, id="single_quantile"),
+ param(lambda df: df.quantile([0.01, 0.99]), 1, id="two_quantiles"),
+ ],
+ )
+ def test_supported_axis_0(self, func, union_count, timedelta_native_df):
+ with SqlCounter(query_count=1, union_count=union_count):
+ eval_snowpark_pandas_result(
+ *create_test_dfs(timedelta_native_df),
+ func,
+ )
+
+ @sql_count_checker(query_count=0)
+ @pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1653126")
+ def test_axis_1(self, timedelta_native_df):
+ eval_snowpark_pandas_result(
+ *create_test_dfs(timedelta_native_df), lambda df: df.sum(axis=1)
+ )
+
+ @sql_count_checker(query_count=0)
+ def test_var_invalid(self, timedelta_native_df):
+ eval_snowpark_pandas_result(
+ *create_test_dfs(timedelta_native_df),
+ lambda df: df.var(),
+ expect_exception=True,
+ expect_exception_type=TypeError,
+ assert_exception_equal=False,
+ expect_exception_match=re.escape(
+ "timedelta64 type does not support var operations"
+ ),
+ )
+
+ @sql_count_checker(query_count=0)
+ @pytest.mark.xfail(
+ strict=True,
+ raises=NotImplementedError,
+ reason="requires concat(), which we cannot do with Timedelta.",
+ )
+ @pytest.mark.parametrize(
+ "operation",
+ [
+ lambda df: df.aggregate({"A": ["count", "max"], "B": [max, "min"]}),
+ lambda df: df.aggregate({"B": ["count"], "A": "sum", "C": ["max", "min"]}),
+ lambda df: df.aggregate(
+ x=pd.NamedAgg("A", "max"), y=("B", "min"), c=("A", "count")
+ ),
+ lambda df: df.aggregate(["min", np.max]),
+ lambda df: df.aggregate(x=("A", "max"), y=("C", "min"), z=("A", "min")),
+ lambda df: df.aggregate(x=("A", "max"), y=pd.NamedAgg("A", "max")),
+ lambda df: df.aggregate(
+ {"B": ["idxmax"], "A": "sum", "C": ["max", "idxmin"]}
+ ),
+ ],
+ )
+ def test_agg_requires_concat_with_timedelta(self, timedelta_native_df, operation):
+ eval_snowpark_pandas_result(*create_test_dfs(timedelta_native_df), operation)
+
+ @sql_count_checker(query_count=0)
+ @pytest.mark.xfail(
+ strict=True,
+ raises=NotImplementedError,
+ reason="requires transposing a one-row frame with integer and timedelta.",
+ )
+ def test_agg_produces_timedelta_and_non_timedelta_type(self, timedelta_native_df):
+ eval_snowpark_pandas_result(
+ *create_test_dfs(timedelta_native_df),
+ lambda df: df.aggregate({"B": "idxmax", "A": "sum"}),
+ )
+
+
@pytest.mark.parametrize(
"func, expected_union_count",
[
diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py
index 1014cae44c9..ded0651046c 100644
--- a/tests/integ/modin/frame/test_apply.py
+++ b/tests/integ/modin/frame/test_apply.py
@@ -337,16 +337,6 @@ def f(x, y, z=1) -> int:
class TestNotImplemented:
- @pytest.mark.parametrize(
- "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP
- )
- @sql_count_checker(query_count=0)
- def test_axis_0(self, data, func, return_type):
- snow_df = pd.DataFrame(data)
- msg = "Snowpark pandas apply API doesn't yet support axis == 0"
- with pytest.raises(NotImplementedError, match=msg):
- snow_df.apply(func)
-
@pytest.mark.parametrize("result_type", ["reduce", "expand", "broadcast"])
@sql_count_checker(query_count=0)
def test_result_type(self, result_type):
@@ -554,33 +544,70 @@ def g(v):
]
-TRANSFORM_DATA_FUNC_MAP = [
- [[[0, 1, 2], [1, 2, 3]], lambda x: x + 1],
- [[[0, 1, 2], [1, 2, 3]], np.exp],
- [[[0, 1, 2], [1, 2, 3]], "exp"],
- [[["Leonhard", "Jianzhun"]], lambda x: x + " is awesome!!"],
- [[[1.3, 2.5]], np.sqrt],
- [[[1.3, 2.5]], "sqrt"],
- [[[1.3, 2.5]], np.log],
- [[[1.3, 2.5]], "log"],
- [[[1.3, 2.5]], np.square],
- [[[1.3, 2.5]], "square"],
+@pytest.mark.xfail(
+ strict=True,
+ raises=SnowparkSQLException,
+ reason="SNOW-1650918: Apply on dataframe data columns containing NULL fails with invalid arguments to udtf function",
+)
+@pytest.mark.parametrize(
+ "data, apply_func",
[
- [[None, "abcd"]],
- lambda x: x + " are first 4 letters of alphabet" if x is not None else None,
+ [
+ [[None, "abcd"]],
+ lambda x: x + " are first 4 letters of alphabet" if x is not None else None,
+ ],
+ [
+ [[123, None]],
+ lambda x: x + 100 if x is not None else None,
+ ],
],
- [[[1.5, float("nan")]], lambda x: np.sqrt(x)],
+)
+def test_apply_bug_1650918(data, apply_func):
+ native_df = native_pd.DataFrame(data)
+ snow_df = pd.DataFrame(native_df)
+ eval_snowpark_pandas_result(
+ snow_df,
+ native_df,
+ lambda x: x.apply(apply_func, axis=1),
+ )
+
+
+TRANSFORM_TEST_MAP = [
+ [[[0, 1, 2], [1, 2, 3]], lambda x: x + 1, 16],
+ [[[0, 1, 2], [1, 2, 3]], np.exp, 16],
+ [[[0, 1, 2], [1, 2, 3]], "exp", None],
+ [[["Leonhard", "Jianzhun"]], lambda x: x + " is awesome!!", 11],
+ [[[1.3, 2.5]], np.sqrt, 11],
+ [[[1.3, 2.5]], "sqrt", None],
+ [[[1.3, 2.5]], np.log, 11],
+ [[[1.3, 2.5]], "log", None],
+ [[[1.3, 2.5]], np.square, 11],
+ [[[1.3, 2.5]], "square", None],
+ [[[1.5, float("nan")]], lambda x: np.sqrt(x), 11],
]
@pytest.mark.modin_sp_precommit
-@pytest.mark.parametrize("data, apply_func", TRANSFORM_DATA_FUNC_MAP)
-@sql_count_checker(query_count=0)
-def test_basic_dataframe_transform(data, apply_func):
- msg = "Snowpark pandas apply API doesn't yet support axis == 0"
- with pytest.raises(NotImplementedError, match=msg):
+@pytest.mark.parametrize("data, apply_func, expected_query_count", TRANSFORM_TEST_MAP)
+def test_basic_dataframe_transform(data, apply_func, expected_query_count):
+ if expected_query_count is None:
+ msg = "Snowpark pandas apply API only supports callables func"
+ with SqlCounter(query_count=0):
+ with pytest.raises(NotImplementedError, match=msg):
+ snow_df = pd.DataFrame(data)
+ snow_df.transform(apply_func)
+ else:
+ msg = "SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function"
+ native_df = native_pd.DataFrame(data)
snow_df = pd.DataFrame(data)
- snow_df.transform(apply_func)
+ with SqlCounter(
+ query_count=expected_query_count,
+ high_count_expected=True,
+ high_count_reason=msg,
+ ):
+ eval_snowpark_pandas_result(
+ snow_df, native_df, lambda x: x.transform(apply_func)
+ )
AGGREGATION_FUNCTIONS = [
@@ -610,7 +637,7 @@ def test_dataframe_transform_invalid_function_name_negative(session):
snow_df = pd.DataFrame([[0, 1, 2], [1, 2, 3]])
with pytest.raises(
NotImplementedError,
- match="Snowpark pandas apply API doesn't yet support axis == 0",
+ match="Snowpark pandas apply API only supports callables func",
):
snow_df.transform("mxyzptlk")
diff --git a/tests/integ/modin/frame/test_apply_axis_0.py b/tests/integ/modin/frame/test_apply_axis_0.py
new file mode 100644
index 00000000000..47fd14d7b98
--- /dev/null
+++ b/tests/integ/modin/frame/test_apply_axis_0.py
@@ -0,0 +1,653 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+import datetime
+
+import modin.pandas as pd
+import numpy as np
+import pandas as native_pd
+import pytest
+
+import snowflake.snowpark.modin.plugin # noqa: F401
+from snowflake.snowpark.exceptions import SnowparkSQLException
+from tests.integ.modin.series.test_apply import create_func_with_return_type_hint
+from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
+from tests.integ.modin.utils import (
+ assert_snowpark_pandas_equal_to_pandas,
+ assert_snowpark_pandas_equals_to_pandas_without_dtypecheck,
+ create_test_dfs,
+ eval_snowpark_pandas_result,
+)
+
+# test data whose functions return a Python type that is not a pandas Series/DataFrame/tuple/list
+BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP = [
+ [[[1.0, 2.2], [3, np.nan]], np.min, "float"],
+ [[[1.1, 2.2], [3, np.nan]], lambda x: x.sum(), "float"],
+ [[[1.1, 2.2], [3, np.nan]], lambda x: x.size, "int"],
+ [[[1.1, 2.2], [3, np.nan]], lambda x: "0" if x.sum() > 1 else 0, "object"],
+ [[["snow", "flake"], ["data", "cloud"]], lambda x: x[0] + x[1], "str"],
+ [[[True, False], [False, False]], lambda x: True, "bool"],
+ [[[True, False], [False, False]], lambda x: x[0] ^ x[1], "bool"],
+ (
+ [
+ [bytes("snow", "utf-8"), bytes("flake", "utf-8")],
+ [bytes("data", "utf-8"), bytes("cloud", "utf-8")],
+ ],
+ lambda x: (x[0] + x[1]).decode(),
+ "str",
+ ),
+ (
+ [[["a", "b"], ["c", "d"]], [["a", "b"], ["c", "d"]]],
+ lambda x: x[0][1] + x[1][0],
+ "str",
+ ),
+ (
+ [[{"a": "b"}, {"c": "d"}], [{"c": "b"}, {"a": "d"}]],
+ lambda x: str(x[0]) + str(x[1]),
+ "str",
+ ),
+]
+
+
+@pytest.mark.parametrize(
+ "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP
+)
+@pytest.mark.modin_sp_precommit
+def test_axis_0_basic_types_without_type_hints(data, func, return_type):
+ # this test processes functions without type hints and invokes the UDTF solution.
+ native_df = native_pd.DataFrame(data, columns=["A", "b"])
+ snow_df = pd.DataFrame(data, columns=["A", "b"])
+ with SqlCounter(
+ query_count=11,
+ join_count=2,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ eval_snowpark_pandas_result(snow_df, native_df, lambda x: x.apply(func, axis=0))
+
+
+@pytest.mark.parametrize(
+ "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP
+)
+@pytest.mark.modin_sp_precommit
+def test_axis_0_basic_types_with_type_hints(data, func, return_type):
+    # explicitly create a UDF with type hints for the supported Python types and process it via vUDF.
+ native_df = native_pd.DataFrame(data, columns=["A", "b"])
+ snow_df = pd.DataFrame(data, columns=["A", "b"])
+ func_with_type_hint = create_func_with_return_type_hint(func, return_type)
+ # Invoking a single UDF typically requires 3 queries (package management, code upload, UDF registration) upfront.
+ with SqlCounter(
+ query_count=11,
+ join_count=2,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ eval_snowpark_pandas_result(
+ snow_df, native_df, lambda x: x.apply(func_with_type_hint, axis=0)
+ )
+
+
+@pytest.mark.parametrize(
+ "df,row_label",
+ [
+ (
+ native_pd.DataFrame(
+ [[1, 2], [None, 3]], columns=["A", "b"], index=["A", "B"]
+ ),
+ "B",
+ ),
+ (
+ native_pd.DataFrame(
+ [[1, 2], [None, 3]],
+ columns=["A", "b"],
+ index=pd.MultiIndex.from_tuples([(1, 2), (1, 1)]),
+ ),
+ (1, 2),
+ ),
+ ],
+)
+def test_axis_0_index_passed_as_name(df, row_label):
+ # when using apply(axis=1) the original index of the dataframe is passed as name.
+    # test this here for both the regular index and multi-index scenarios.
+
+ def foo(row) -> str:
+ if row.name == row_label:
+ return "MATCHING LABEL"
+ else:
+ return "NO MATCH"
+
+ snow_df = pd.DataFrame(df)
+ with SqlCounter(
+ query_count=11,
+ join_count=2,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=0))
+
+
+@sql_count_checker(
+ query_count=11,
+ join_count=3,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+)
+def test_axis_0_return_series():
+ snow_df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"])
+ native_df = native_pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"])
+ eval_snowpark_pandas_result(
+ snow_df,
+ native_df,
+ lambda x: x.apply(lambda x: native_pd.Series([1, 2], index=["C", "d"]), axis=0),
+ )
+
+
+@sql_count_checker(
+ query_count=11,
+ join_count=3,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+)
+def test_axis_0_return_series_with_different_label_results():
+ df = native_pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"])
+ snow_df = pd.DataFrame(df)
+
+ eval_snowpark_pandas_result(
+ snow_df,
+ df,
+ lambda df: df.apply(
+ lambda x: native_pd.Series([1, 2], index=["a", "b"])
+ if x.sum() > 3
+ else native_pd.Series([0, 1, 2], index=["c", "a", "b"]),
+ axis=0,
+ ),
+ )
+
+
+@sql_count_checker(query_count=6, join_count=1, udtf_count=1)
+def test_axis_0_return_single_scalar_series():
+ native_df = native_pd.DataFrame([1])
+ snow_df = pd.DataFrame(native_df)
+
+ def apply_func(x):
+ return native_pd.Series([1], index=["xyz"])
+
+ eval_snowpark_pandas_result(
+ snow_df, native_df, lambda x: x.apply(apply_func, axis=0)
+ )
+
+
+@sql_count_checker(query_count=3)
+def test_axis_0_return_dataframe_not_supported():
+ snow_df = pd.DataFrame([1])
+
+    # Note that pandas fails with "ValueError: If using all scalar values, you must pass an index", which
+    # doesn't explain that this isn't supported. We go with the default error returned in this case.
+ with pytest.raises(
+ SnowparkSQLException, match="The truth value of a DataFrame is ambiguous."
+ ):
+ # return value
+ snow_df.apply(lambda x: native_pd.DataFrame([1, 2]), axis=0).to_pandas()
+
+
+class TestNotImplemented:
+ @pytest.mark.parametrize("result_type", ["reduce", "expand", "broadcast"])
+ @sql_count_checker(query_count=0)
+ def test_result_type(self, result_type):
+ snow_df = pd.DataFrame([[1, 2], [3, 4]])
+ msg = "Snowpark pandas apply API doesn't yet support 'result_type' parameter"
+ with pytest.raises(NotImplementedError, match=msg):
+ snow_df.apply(lambda x: [1, 2], axis=0, result_type=result_type)
+
+ @sql_count_checker(query_count=0)
+ def test_axis_1_apply_args_kwargs_with_snowpandas_object(self):
+ def f(x, y=None) -> native_pd.Series:
+ return x + (y if y is not None else 0)
+
+ snow_df = pd.DataFrame([[1, 2], [3, 4]])
+ msg = "Snowpark pandas apply API doesn't yet support DataFrame or Series in 'args' or 'kwargs' of 'func'"
+ with pytest.raises(NotImplementedError, match=msg):
+ snow_df.apply(f, axis=0, args=(pd.Series([1, 2]),))
+ with pytest.raises(NotImplementedError, match=msg):
+ snow_df.apply(f, axis=0, y=pd.Series([1, 2]))
+
+
+TEST_INDEX_1 = native_pd.MultiIndex.from_tuples(
+ list(zip(*[["a", "b"], ["x", "y"]])),
+ names=["first", "last"],
+)
+
+
+TEST_INDEX_WITH_NULL_1 = native_pd.MultiIndex.from_tuples(
+ list(zip(*[[None, "b"], ["x", None]])),
+ names=["first", "last"],
+)
+
+
+TEST_INDEX_2 = native_pd.MultiIndex.from_tuples(
+ list(zip(*[["AA", "BB"], ["XX", "YY"]])),
+ names=["FOO", "BAR"],
+)
+
+TEST_INDEX_WITH_NULL_2 = native_pd.MultiIndex.from_tuples(
+ list(zip(*[[None, "BB"], ["XX", None]])),
+ names=["FOO", "BAR"],
+)
+
+
+TEST_COLUMNS_1 = native_pd.MultiIndex.from_tuples(
+ list(
+ zip(
+ *[
+ ["car", "motorcycle", "bike", "bus"],
+ ["blue", "green", "red", "yellow"],
+ ]
+ )
+ ),
+ names=["vehicle", "color"],
+)
+
+
+@pytest.mark.parametrize(
+ "apply_func, expected_join_count, expected_union_count",
+ [
+ [lambda x: [1, 2], 3, 0],
+ [lambda x: x + 1 if x is not None else None, 3, 0],
+ [lambda x: x.min(), 2, 1],
+ ],
+)
+def test_axis_0_series_basic(apply_func, expected_join_count, expected_union_count):
+ native_df = native_pd.DataFrame(
+ [[1.1, 2.2], [3.0, None]], index=pd.Index([2, 3]), columns=["A", "b"]
+ )
+ snow_df = pd.DataFrame(native_df)
+ with SqlCounter(
+ query_count=11,
+ join_count=expected_join_count,
+ udtf_count=2,
+ union_count=expected_union_count,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ eval_snowpark_pandas_result(
+ snow_df,
+ native_df,
+ lambda df: df.apply(apply_func, axis=0),
+ )
+
+
+@sql_count_checker(query_count=5, join_count=1, udtf_count=1)
+def test_groupby_apply_constant_output():
+ native_df = native_pd.DataFrame([1, 2])
+ native_df["fg"] = 0
+ snow_df = pd.DataFrame(native_df)
+
+ eval_snowpark_pandas_result(
+ snow_df,
+ native_df,
+ lambda df: df.groupby(by=["fg"], axis=0).apply(lambda x: [1, 2]),
+ )
+
+
+@sql_count_checker(
+ query_count=11,
+ join_count=3,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+)
+def test_axis_0_return_list():
+ snow_df = pd.DataFrame([[1, 2], [3, 4]])
+ native_df = native_pd.DataFrame([[1, 2], [3, 4]])
+ eval_snowpark_pandas_result(
+ snow_df, native_df, lambda x: x.apply(lambda x: [1, 2], axis=0)
+ )
+
+
+@pytest.mark.parametrize(
+ "apply_func",
+ [
+ lambda x: -x,
+ lambda x: native_pd.Series([1, 2], index=TEST_INDEX_1),
+ lambda x: native_pd.Series([3, 4], index=TEST_INDEX_2),
+ lambda x: native_pd.Series([1, 2], index=TEST_INDEX_WITH_NULL_1),
+ lambda x: native_pd.Series([1, 2], index=TEST_INDEX_WITH_NULL_1),
+ ],
+)
+@sql_count_checker(
+ query_count=21,
+ join_count=7,
+ udtf_count=4,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+)
+def test_axis_0_multi_index_column_labels(apply_func):
+ data = [[i + j for j in range(0, 4)] for i in range(0, 4)]
+
+ native_df = native_pd.DataFrame(data, columns=TEST_COLUMNS_1)
+ snow_df = pd.DataFrame(data, columns=TEST_COLUMNS_1)
+
+ eval_snowpark_pandas_result(
+ snow_df, native_df, lambda x: x.apply(apply_func, axis=0)
+ )
+
+
+@sql_count_checker(
+ query_count=21,
+ join_count=7,
+ udtf_count=4,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+)
+def test_axis_0_multi_index_column_labels_with_different_results():
+ data = [[i + j for j in range(0, 4)] for i in range(0, 4)]
+
+ df = native_pd.DataFrame(data, columns=TEST_COLUMNS_1)
+ snow_df = pd.DataFrame(df)
+
+ apply_func = (
+ lambda x: native_pd.Series([1, 2], index=TEST_INDEX_1)
+ if min(x) == 0
+ else native_pd.Series([3, 4], index=TEST_INDEX_2)
+ )
+
+ eval_snowpark_pandas_result(snow_df, df, lambda df: df.apply(apply_func, axis=0))
+
+
+@pytest.mark.parametrize(
+ "data, func, expected_result",
+ [
+ [
+ [
+ [datetime.date(2023, 1, 1), None],
+ [datetime.date(2022, 12, 31), datetime.date(2021, 1, 9)],
+ ],
+ lambda x: x.dt.day,
+ native_pd.DataFrame([[1, np.nan], [31, 9.0]]),
+ ],
+ [
+ [
+ [datetime.time(1, 2, 3), None],
+ [datetime.time(1, 2, 3, 1), datetime.time(1)],
+ ],
+ lambda x: x.dt.seconds,
+ native_pd.DataFrame([[3723, np.nan], [3723, 3600]]),
+ ],
+ [
+ [
+ [datetime.datetime(2023, 1, 1, 1, 2, 3), None],
+ [
+ datetime.datetime(2022, 12, 31, 1, 2, 3, 1),
+ datetime.datetime(
+ 2023, 1, 1, 1, 2, 3, tzinfo=datetime.timezone.utc
+ ),
+ ],
+ ],
+ lambda x: x.astype(str),
+ native_pd.DataFrame(
+ [
+ ["2023-01-01 01:02:03.000000", "NaT"],
+ ["2022-12-31 01:02:03.000001", "2023-01-01 01:02:03+00:00"],
+ ]
+ ),
+ ],
+ ],
+)
+@sql_count_checker(
+ query_count=11,
+ join_count=3,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+)
+def test_axis_0_date_time_timestamp_type(data, func, expected_result):
+ snow_df = pd.DataFrame(data)
+ result = snow_df.apply(func, axis=0)
+
+ assert_snowpark_pandas_equal_to_pandas(result, expected_result)
+
+
+@pytest.mark.parametrize(
+ "native_df, func",
+ [
+ (
+ native_pd.DataFrame([[1, 2], [3, 4]], index=["a", "b"]),
+ lambda x: x["a"] + x["b"],
+ ),
+ (
+ native_pd.DataFrame(
+ [[1, 5], [2, 6], [3, 7], [4, 8]],
+ index=native_pd.MultiIndex.from_tuples(
+ [("baz", "A"), ("baz", "B"), ("zoo", "A"), ("zoo", "B")]
+ ),
+ ),
+ lambda x: x["baz", "B"] * x["zoo", "A"],
+ ),
+ ],
+)
+@sql_count_checker(
+ query_count=11,
+ join_count=2,
+ udtf_count=2,
+ union_count=1,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+)
+def test_axis_0_index_labels(native_df, func):
+ snow_df = pd.DataFrame(native_df)
+ eval_snowpark_pandas_result(snow_df, native_df, lambda x: x.apply(func, axis=0))
+
+
+@sql_count_checker(
+ query_count=11,
+ join_count=2,
+ udtf_count=2,
+ union_count=1,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+)
+def test_axis_0_raw():
+ snow_df = pd.DataFrame([[1, 2], [3, 4]])
+ native_df = native_pd.DataFrame([[1, 2], [3, 4]])
+ eval_snowpark_pandas_result(
+ snow_df, native_df, lambda x: x.apply(lambda x: str(type(x)), axis=0, raw=True)
+ )
+
+
+def test_axis_0_apply_args_kwargs():
+ def f(x, y, z=1) -> int:
+ return x.sum() + y + z
+
+ native_df = native_pd.DataFrame([[1, 2], [3, 4]])
+ snow_df = pd.DataFrame([[1, 2], [3, 4]])
+
+ with SqlCounter(query_count=3):
+ eval_snowpark_pandas_result(
+ snow_df,
+ native_df,
+ lambda x: x.apply(f, axis=0),
+ expect_exception=True,
+ expect_exception_type=SnowparkSQLException,
+ expect_exception_match="missing 1 required positional argument",
+ assert_exception_equal=False,
+ )
+
+ with SqlCounter(
+ query_count=11,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ eval_snowpark_pandas_result(
+ snow_df, native_df, lambda x: x.apply(f, axis=0, args=(1,))
+ )
+
+ with SqlCounter(
+ query_count=11,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ eval_snowpark_pandas_result(
+ snow_df, native_df, lambda x: x.apply(f, axis=0, args=(1,), z=2)
+ )
+
+ with SqlCounter(query_count=3):
+ eval_snowpark_pandas_result(
+ snow_df,
+ native_df,
+ lambda x: x.apply(f, axis=0, args=(1,), z=2, v=3),
+ expect_exception=True,
+ expect_exception_type=SnowparkSQLException,
+ expect_exception_match="got an unexpected keyword argument",
+ assert_exception_equal=False,
+ )
+
+
+@pytest.mark.parametrize("data", [{"a": [1], "b": [2]}, {"a": [2], "b": [3]}])
+def test_apply_axis_0_with_if_where_duplicates_not_executed(data):
+ df = native_pd.DataFrame(data)
+ snow_df = pd.DataFrame(df)
+
+ def foo(x):
+ return native_pd.Series(
+ [1, 2, 3], index=["C", "A", "E"] if x.sum() > 3 else ["A", "E", "E"]
+ )
+
+ with SqlCounter(
+ query_count=11,
+ join_count=3,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=0))
+
+
+@pytest.mark.parametrize(
+ "return_value",
+ [
+ native_pd.Series(["a", np.int64(3)]),
+ ["a", np.int64(3)],
+ np.int64(3),
+ ],
+)
+@sql_count_checker(query_count=6, join_count=1, udtf_count=1)
+def test_numpy_integers_in_return_values_snow_1227264(return_value):
+ eval_snowpark_pandas_result(
+ *create_test_dfs(["a"]), lambda df: df.apply(lambda row: return_value, axis=0)
+ )
+
+
+@pytest.mark.xfail(
+ strict=True,
+ raises=SnowparkSQLException,
+ reason="SNOW-1650918: Apply on dataframe data columns containing NULL fails with invalid arguments to udtf function",
+)
+@pytest.mark.parametrize(
+ "data, apply_func",
+ [
+ [
+ [[None, "abcd"]],
+ lambda x: x + " are first 4 letters of alphabet" if x is not None else None,
+ ],
+ [
+ [[123, None]],
+ lambda x: x + 100 if x is not None else None,
+ ],
+ ],
+)
+def test_apply_axis_0_bug_1650918(data, apply_func):
+ native_df = native_pd.DataFrame(data)
+ snow_df = pd.DataFrame(native_df)
+ eval_snowpark_pandas_result(
+ snow_df,
+ native_df,
+ lambda x: x.apply(apply_func, axis=0),
+ )
+
+
+def test_apply_nested_series_negative():
+ snow_df = pd.DataFrame([[1, 2], [3, 4]])
+
+ with SqlCounter(
+ query_count=10,
+ join_count=2,
+ udtf_count=2,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ with pytest.raises(
+ NotImplementedError,
+ match=r"Nested pd.Series in result is not supported in DataFrame.apply\(axis=0\)",
+ ):
+ snow_df.apply(
+ lambda ser: 99 if ser.sum() == 4 else native_pd.Series([1, 2]), axis=0
+ ).to_pandas()
+
+ snow_df2 = pd.DataFrame([[1, 2, 3]])
+
+ with SqlCounter(
+ query_count=15,
+ join_count=3,
+ udtf_count=3,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ with pytest.raises(
+ NotImplementedError,
+ match=r"Nested pd.Series in result is not supported in DataFrame.apply\(axis=0\)",
+ ):
+ snow_df2.apply(
+ lambda ser: 99
+ if ser.sum() == 2
+ else native_pd.Series([100], index=["a"]),
+ axis=0,
+ ).to_pandas()
+
+
+import scipy.stats # noqa: E402
+
+
+@pytest.mark.parametrize(
+ "packages,expected_query_count",
+ [
+ (["scipy", "numpy"], 26),
+ (["scipy>1.1", "numpy<2.0"], 26),
+ # TODO: SNOW-1478188 Re-enable quarantined tests for 8.23
+ # [scipy, np], 9),
+ ],
+)
+def test_apply_axis0_with_3rd_party_libraries_and_decorator(
+ packages, expected_query_count
+):
+ data = [[1, 2, 3, 4, 5], [7, -20, 4.0, 7.0, None]]
+
+ with SqlCounter(
+ query_count=expected_query_count,
+ high_count_expected=True,
+ high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function",
+ ):
+ try:
+ pd.session.custom_package_usage_config["enabled"] = True
+ pd.session.add_packages(packages)
+
+ df = pd.DataFrame(data)
+
+ def func(row):
+ return np.dot(row, scipy.stats.norm.pdf(row))
+
+ snow_ans = df.apply(func, axis=0)
+ finally:
+ pd.session.clear_packages()
+ pd.session.clear_imports()
+
+ # same in native pandas:
+ native_df = native_pd.DataFrame(data)
+ native_ans = native_df.apply(func, axis=0)
+
+ assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(snow_ans, native_ans)
diff --git a/tests/integ/modin/frame/test_describe.py b/tests/integ/modin/frame/test_describe.py
index a9668c5794f..4f1882d441d 100644
--- a/tests/integ/modin/frame/test_describe.py
+++ b/tests/integ/modin/frame/test_describe.py
@@ -358,3 +358,18 @@ def test_describe_object_file(resources_path):
df = pd.read_csv(test_files.test_concat_file1_csv)
native_df = df.to_pandas()
eval_snowpark_pandas_result(df, native_df, lambda x: x.describe(include="O"))
+
+
+@sql_count_checker(query_count=0)
+@pytest.mark.xfail(
+ strict=True,
+ raises=NotImplementedError,
+ reason="requires concat(), which we cannot do with Timedelta.",
+)
+def test_timedelta(timedelta_native_df):
+ eval_snowpark_pandas_result(
+ *create_test_dfs(
+ timedelta_native_df,
+ ),
+ lambda df: df.describe(),
+ )
diff --git a/tests/integ/modin/frame/test_dtypes.py b/tests/integ/modin/frame/test_dtypes.py
index c3773bdd6de..b078b31f6c5 100644
--- a/tests/integ/modin/frame/test_dtypes.py
+++ b/tests/integ/modin/frame/test_dtypes.py
@@ -351,7 +351,7 @@ def test_insert_multiindex_multi_label(label1, label2):
native_pd.Timestamp(1513393355, unit="s", tz="US/Pacific"),
],
"datetime64[ns, America/Los_Angeles]",
- "datetime64[ns, America/Los_Angeles]",
+ "datetime64[ns, UTC-08:00]",
"datetime64[ns]",
),
(
@@ -372,7 +372,7 @@ def test_insert_multiindex_multi_label(label1, label2):
native_pd.Timestamp(1513393355, unit="s", tz="US/Pacific"),
],
"object",
- "datetime64[ns, America/Los_Angeles]",
+ "datetime64[ns, UTC-08:00]",
"datetime64[ns]",
),
],
diff --git a/tests/integ/modin/frame/test_idxmax_idxmin.py b/tests/integ/modin/frame/test_idxmax_idxmin.py
index 72fe88968bc..87041060bd2 100644
--- a/tests/integ/modin/frame/test_idxmax_idxmin.py
+++ b/tests/integ/modin/frame/test_idxmax_idxmin.py
@@ -196,8 +196,18 @@ def test_idxmax_idxmin_with_dates(func, axis):
@sql_count_checker(query_count=1)
@pytest.mark.parametrize("func", ["idxmax", "idxmin"])
-@pytest.mark.parametrize("axis", [0, 1])
-@pytest.mark.xfail(reason="SNOW-1625380 TODO")
+@pytest.mark.parametrize(
+ "axis",
+ [
+ 0,
+ pytest.param(
+ 1,
+ marks=pytest.mark.xfail(
+ strict=True, raises=NotImplementedError, reason="SNOW-1653126"
+ ),
+ ),
+ ],
+)
def test_idxmax_idxmin_with_timedelta(func, axis):
native_df = native_pd.DataFrame(
data={
diff --git a/tests/integ/modin/frame/test_info.py b/tests/integ/modin/frame/test_info.py
index 2a096e76fdc..fbbf8dfe041 100644
--- a/tests/integ/modin/frame/test_info.py
+++ b/tests/integ/modin/frame/test_info.py
@@ -13,9 +13,7 @@
def _assert_info_lines_equal(modin_info: list[str], pandas_info: list[str]):
# class is different
-    assert (
-        modin_info[0] == "<class 'modin.pandas.dataframe.DataFrame'>"
-    )
+    assert modin_info[0] == "<class 'modin.pandas.dataframe.DataFrame'>"
     assert pandas_info[0] == "<class 'pandas.core.frame.DataFrame'>"
# index is different
diff --git a/tests/integ/modin/frame/test_loc.py b/tests/integ/modin/frame/test_loc.py
index be51b8c9ae6..105bf475f3a 100644
--- a/tests/integ/modin/frame/test_loc.py
+++ b/tests/integ/modin/frame/test_loc.py
@@ -4072,3 +4072,22 @@ def test_df_loc_get_with_timedelta_and_none_key():
# Compare with an empty DataFrame, since native pandas raises a KeyError.
expected_df = native_pd.DataFrame()
assert_frame_equal(snow_df.loc[None], expected_df, check_column_type=False)
+
+
+@sql_count_checker(query_count=0)
+def test_df_loc_invalid_key():
+ # Bug fix: SNOW-1320674
+ native_df = native_pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
+ snow_df = pd.DataFrame(native_df)
+
+ def op(df):
+ df["C"] = df["A"] / df["D"]
+
+ eval_snowpark_pandas_result(
+ snow_df,
+ native_df,
+ op,
+ expect_exception=True,
+ expect_exception_type=KeyError,
+ expect_exception_match="D",
+ )
diff --git a/tests/integ/modin/frame/test_nunique.py b/tests/integ/modin/frame/test_nunique.py
index d0cad8ec2ad..78098d34386 100644
--- a/tests/integ/modin/frame/test_nunique.py
+++ b/tests/integ/modin/frame/test_nunique.py
@@ -11,8 +11,13 @@
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result
-TEST_LABELS = np.array(["A", "B", "C", "D"])
-TEST_DATA = [[0, 1, 2, 3], [0, 0, 0, 0], [None, 0, None, 0], [None, None, None, None]]
+TEST_LABELS = np.array(["A", "B", "C", "D", "E"])
+TEST_DATA = [
+ [0, 1, 2, 3, pd.Timedelta(4)],
+ [0, 0, 0, 0, pd.Timedelta(0)],
+ [None, 0, None, 0, pd.Timedelta(0)],
+ [None, None, None, None, None],
+]
# which original dataframe (constructed from slicing) to test for
TEST_SLICES = [
@@ -80,7 +85,7 @@ def test_dataframe_nunique_no_columns(native_df):
[
pytest.param(None, id="default_columns"),
pytest.param(
- [["bar", "bar", "baz", "foo"], ["one", "two", "one", "two"]],
+ [["bar", "bar", "baz", "foo", "foo"], ["one", "two", "one", "two", "one"]],
id="2D_columns",
),
],
diff --git a/tests/integ/modin/frame/test_skew.py b/tests/integ/modin/frame/test_skew.py
index 72fad6cebdc..94b7fd79c24 100644
--- a/tests/integ/modin/frame/test_skew.py
+++ b/tests/integ/modin/frame/test_skew.py
@@ -8,7 +8,11 @@
import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import sql_count_checker
-from tests.integ.modin.utils import assert_series_equal
+from tests.integ.modin.utils import (
+ assert_series_equal,
+ create_test_dfs,
+ eval_snowpark_pandas_result,
+)
@sql_count_checker(query_count=1)
@@ -62,16 +66,22 @@ def test_skew_basic():
},
"kwargs": {"numeric_only": True, "skipna": True},
},
+ {
+ "frame": {
+ "A": [pd.Timedelta(1)],
+ },
+ "kwargs": {
+ "numeric_only": True,
+ },
+ },
],
)
@sql_count_checker(query_count=1)
def test_skew(data):
- native_df = native_pd.DataFrame(data["frame"])
- snow_df = pd.DataFrame(native_df)
- assert_series_equal(
- snow_df.skew(**data["kwargs"]),
- native_df.skew(**data["kwargs"]),
- rtol=1.0e-5,
+ eval_snowpark_pandas_result(
+ *create_test_dfs(data["frame"]),
+ lambda df: df.skew(**data["kwargs"]),
+ rtol=1.0e-5
)
@@ -103,6 +113,14 @@ def test_skew(data):
},
"kwargs": {"level": 2},
},
+ {
+ "frame": {
+ "A": [pd.Timedelta(1)],
+ },
+ "kwargs": {
+ "numeric_only": False,
+ },
+ },
],
)
@sql_count_checker(query_count=0)
diff --git a/tests/integ/modin/groupby/test_all_any.py b/tests/integ/modin/groupby/test_all_any.py
index d5234dfbdb5..df8df44d47c 100644
--- a/tests/integ/modin/groupby/test_all_any.py
+++ b/tests/integ/modin/groupby/test_all_any.py
@@ -14,7 +14,11 @@
import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark.exceptions import SnowparkSQLException
from tests.integ.modin.sql_counter import sql_count_checker
-from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result
+from tests.integ.modin.utils import (
+ assert_frame_equal,
+ create_test_dfs,
+ eval_snowpark_pandas_result,
+)
@pytest.mark.parametrize(
@@ -109,3 +113,27 @@ def test_all_any_chained():
lambda df: df.apply(lambda ser: ser.str.len())
)
)
+
+
+@sql_count_checker(query_count=1)
+def test_timedelta_any_with_nulls():
+ """
+ Test this case separately because pandas behavior is different from Snowpark pandas behavior.
+
+ pandas bug that does not apply to Snowpark pandas:
+ https://github.com/pandas-dev/pandas/issues/59712
+ """
+ snow_df, native_df = create_test_dfs(
+ {
+ "key": ["a"],
+ "A": native_pd.Series([pd.NaT], dtype="timedelta64[ns]"),
+ },
+ )
+ assert_frame_equal(
+ native_df.groupby("key").any(),
+ native_pd.DataFrame({"A": [True]}, index=native_pd.Index(["a"], name="key")),
+ )
+ assert_frame_equal(
+ snow_df.groupby("key").any(),
+ native_pd.DataFrame({"A": [False]}, index=native_pd.Index(["a"], name="key")),
+ )
diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py
index 09acd49bb21..cbf5b75d48c 100644
--- a/tests/integ/modin/groupby/test_groupby_basic_agg.py
+++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py
@@ -1096,60 +1096,81 @@ def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df):
)
-@pytest.mark.parametrize(
- "agg_func",
- [
- "count",
- "sum",
- "mean",
- "median",
- "std",
- ],
-)
-@pytest.mark.parametrize("by", ["A", "B"])
-@sql_count_checker(query_count=1)
-def test_timedelta(agg_func, by):
- native_df = native_pd.DataFrame(
- {
- "A": native_pd.to_timedelta(
- ["1 days 06:05:01.00003", "16us", "nan", "16us"]
- ),
- "B": [8, 8, 12, 10],
- }
- )
- snow_df = pd.DataFrame(native_df)
-
- eval_snowpark_pandas_result(
- snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)()
- )
-
-
-def test_timedelta_groupby_agg():
- native_df = native_pd.DataFrame(
- {
- "A": native_pd.to_timedelta(
- ["1 days 06:05:01.00003", "16us", "nan", "16us"]
- ),
- "B": [8, 8, 12, 10],
- "C": [True, False, False, True],
- }
+class TestTimedelta:
+ @sql_count_checker(query_count=1)
+ @pytest.mark.parametrize(
+ "method",
+ [
+ "count",
+ "mean",
+ "min",
+ "max",
+ "idxmax",
+ "idxmin",
+ "sum",
+ "median",
+ "std",
+ "nunique",
+ ],
)
- snow_df = pd.DataFrame(native_df)
- with SqlCounter(query_count=1):
+ @pytest.mark.parametrize("by", ["A", "B"])
+ def test_aggregation_methods(self, method, by):
eval_snowpark_pandas_result(
- snow_df,
- native_df,
- lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}),
+ *create_test_dfs(
+ {
+ "A": native_pd.to_timedelta(
+ ["1 days 06:05:01.00003", "16us", "nan", "16us"]
+ ),
+ "B": [8, 8, 12, 10],
+ }
+ ),
+ lambda df: getattr(df.groupby(by), method)(),
)
- with SqlCounter(query_count=1):
- eval_snowpark_pandas_result(
- snow_df,
- native_df,
+
+ @sql_count_checker(query_count=1)
+ @pytest.mark.parametrize(
+ "operation",
+ [
+ lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}),
lambda df: df.groupby("B").agg({"A": ["sum", "median"], "C": "min"}),
+ lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}),
+ lambda df: df.groupby("B").agg(["mean", "std"]),
+ lambda df: df.groupby("B").agg({"A": ["count", np.sum]}),
+ lambda df: df.groupby("B").agg({"A": "sum"}),
+ ],
+ )
+ def test_agg(self, operation):
+ eval_snowpark_pandas_result(
+ *create_test_dfs(
+ native_pd.DataFrame(
+ {
+ "A": native_pd.to_timedelta(
+ ["1 days 06:05:01.00003", "16us", "nan", "16us"]
+ ),
+ "B": [8, 8, 12, 10],
+ "C": [True, False, False, True],
+ }
+ )
+ ),
+ operation,
)
- with SqlCounter(query_count=1):
+
+ @sql_count_checker(query_count=1)
+ def test_groupby_timedelta_var(self):
+ """
+ Test that we can group by a timedelta column and take var() of an integer column.
+
+ Note that we can't take the groupby().var() of the timedelta column because
+ var() is not defined for timedelta, in pandas or in Snowpark pandas.
+ """
eval_snowpark_pandas_result(
- snow_df,
- native_df,
- lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}),
+ *create_test_dfs(
+ {
+ "A": native_pd.to_timedelta(
+ ["1 days 06:05:01.00003", "16us", "nan", "16us"]
+ ),
+ "B": [8, 8, 12, 10],
+ }
+ ),
+ lambda df: df.groupby("A").var(),
)
diff --git a/tests/integ/modin/groupby/test_groupby_first_last.py b/tests/integ/modin/groupby/test_groupby_first_last.py
index 5da35806dd1..5e04d5a6fc2 100644
--- a/tests/integ/modin/groupby/test_groupby_first_last.py
+++ b/tests/integ/modin/groupby/test_groupby_first_last.py
@@ -46,6 +46,17 @@
[np.nan],
]
),
+ "col11_timedelta": [
+ pd.Timedelta("1 days"),
+ None,
+ pd.Timedelta("2 days"),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ ],
}
diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py
index a009e1089b0..0c9c056c2a7 100644
--- a/tests/integ/modin/groupby/test_groupby_negative.py
+++ b/tests/integ/modin/groupby/test_groupby_negative.py
@@ -18,6 +18,7 @@
MAP_DATA_AND_TYPE,
MIXED_NUMERIC_STR_DATA_AND_TYPE,
TIMESTAMP_DATA_AND_TYPE,
+ create_test_dfs,
eval_snowpark_pandas_result,
)
@@ -559,20 +560,12 @@ def test_groupby_agg_invalid_min_count(
@sql_count_checker(query_count=0)
-def test_groupby_var_no_support_for_timedelta():
- native_df = native_pd.DataFrame(
- {
- "A": native_pd.to_timedelta(
- ["1 days 06:05:01.00003", "15.5us", "nan", "16us"]
- ),
- "B": [8, 8, 12, 10],
- }
- )
- snow_df = pd.DataFrame(native_df)
- with pytest.raises(
- NotImplementedError,
- match=re.escape(
- "SnowflakeQueryCompiler::groupby_agg is not yet implemented for Timedelta Type"
+def test_timedelta_var_invalid():
+ eval_snowpark_pandas_result(
+ *create_test_dfs(
+ [["key0", pd.Timedelta(1)]],
),
- ):
- snow_df.groupby("B").var()
+ lambda df: df.groupby(0).var(),
+ expect_exception=True,
+ expect_exception_type=TypeError,
+ )
diff --git a/tests/integ/modin/groupby/test_quantile.py b/tests/integ/modin/groupby/test_quantile.py
index b14299fee63..940d366a7e2 100644
--- a/tests/integ/modin/groupby/test_quantile.py
+++ b/tests/integ/modin/groupby/test_quantile.py
@@ -64,6 +64,14 @@
# ),
# All NA
([np.nan] * 5, [np.nan] * 5),
+ pytest.param(
+ pd.timedelta_range(
+ "1 days",
+ "5 days",
+ ),
+ pd.timedelta_range("1 second", "5 second"),
+ id="timedelta",
+ ),
],
)
@pytest.mark.parametrize("q", [0, 0.5, 1])
diff --git a/tests/integ/modin/index/conftest.py b/tests/integ/modin/index/conftest.py
index 3c6362dd83c..26afd232c4f 100644
--- a/tests/integ/modin/index/conftest.py
+++ b/tests/integ/modin/index/conftest.py
@@ -33,11 +33,11 @@
native_pd.Index(["a", "b", "c", "d"]),
native_pd.DatetimeIndex(
["2020-01-01 10:00:00+00:00", "2020-02-01 11:00:00+00:00"],
- tz="America/Los_Angeles",
+ tz="UTC-08:00",
),
native_pd.DatetimeIndex(
["2020-01-01 10:00:00+05:00", "2020-02-01 11:00:00+05:00"],
- tz="America/Los_Angeles",
+ tz="UTC",
),
native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]),
native_pd.TimedeltaIndex(["0 days", "1 days", "3 days"]),
@@ -55,11 +55,11 @@
native_pd.Index(["a", "b", 1, 2, None, "a", 2], name="mixed index"),
native_pd.DatetimeIndex(
["2020-01-01 10:00:00+00:00", "2020-02-01 11:00:00+00:00"],
- tz="America/Los_Angeles",
+ tz="UTC",
),
native_pd.DatetimeIndex(
["2020-01-01 10:00:00+00:00", "2020-01-01 10:00:00+00:00"],
- tz="America/Los_Angeles",
+ tz="UTC-08:00",
),
]
@@ -79,4 +79,5 @@
tz="America/Los_Angeles",
),
native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]),
+ native_pd.TimedeltaIndex(["4 days", None, "-1 days", "5 days"]),
]
diff --git a/tests/integ/modin/index/test_all_any.py b/tests/integ/modin/index/test_all_any.py
index 267e7929ea1..499be6f03dc 100644
--- a/tests/integ/modin/index/test_all_any.py
+++ b/tests/integ/modin/index/test_all_any.py
@@ -25,6 +25,9 @@
native_pd.Index(["a", "b", "c", "d"]),
native_pd.Index([5, None, 7]),
native_pd.Index([], dtype="object"),
+ native_pd.Index([pd.Timedelta(0), None]),
+ native_pd.Index([pd.Timedelta(0)]),
+ native_pd.Index([pd.Timedelta(0), pd.Timedelta(1)]),
]
NATIVE_INDEX_EMPTY_DATA = [
diff --git a/tests/integ/modin/index/test_argmax_argmin.py b/tests/integ/modin/index/test_argmax_argmin.py
index 6d446a0a66a..7d42f3b88c9 100644
--- a/tests/integ/modin/index/test_argmax_argmin.py
+++ b/tests/integ/modin/index/test_argmax_argmin.py
@@ -18,6 +18,18 @@
native_pd.Index([4, None, 1, 3, 4, 1]),
native_pd.Index([4, None, 1, 3, 4, 1], name="some name"),
native_pd.Index([1, 10, 4, 3, 4]),
+ pytest.param(
+ native_pd.Index(
+ [
+ pd.Timedelta(1),
+ pd.Timedelta(10),
+ pd.Timedelta(4),
+ pd.Timedelta(3),
+ pd.Timedelta(4),
+ ]
+ ),
+ id="timedelta",
+ ),
],
)
@pytest.mark.parametrize("func", ["argmax", "argmin"])
diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py
index 56fd40a6cb3..98d1a041c3b 100644
--- a/tests/integ/modin/index/test_datetime_index_methods.py
+++ b/tests/integ/modin/index/test_datetime_index_methods.py
@@ -4,8 +4,10 @@
import re
import modin.pandas as pd
+import numpy as np
import pandas as native_pd
import pytest
+import pytz
import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
@@ -16,6 +18,46 @@
eval_snowpark_pandas_result,
)
+timezones = pytest.mark.parametrize(
+ "tz",
+ [
+ None,
+        # Use a subset of pytz.common_timezones containing a few timezones in each region of the world
+ *[
+ param_for_one_tz
+ for tz in [
+ "Africa/Abidjan",
+ "Africa/Timbuktu",
+ "America/Adak",
+ "America/Yellowknife",
+ "Antarctica/Casey",
+ "Asia/Dhaka",
+ "Asia/Manila",
+ "Asia/Shanghai",
+ "Atlantic/Stanley",
+ "Australia/Sydney",
+ "Canada/Pacific",
+ "Europe/Chisinau",
+ "Europe/Luxembourg",
+ "Indian/Christmas",
+ "Pacific/Chatham",
+ "Pacific/Wake",
+ "US/Arizona",
+ "US/Central",
+ "US/Eastern",
+ "US/Hawaii",
+ "US/Mountain",
+ "US/Pacific",
+ "UTC",
+ ]
+ for param_for_one_tz in (
+ pytz.timezone(tz),
+ tz,
+ )
+ ],
+ ],
+)
+
@sql_count_checker(query_count=0)
def test_datetime_index_construction():
@@ -100,13 +142,13 @@ def test_index_parent():
# DataFrame case.
df = pd.DataFrame({"A": [1]}, index=native_idx1)
snow_idx = df.index
- assert_frame_equal(snow_idx._parent, df)
+ assert_frame_equal(snow_idx._parent._parent, df)
assert_index_equal(snow_idx, native_idx1)
# Series case.
s = pd.Series([1, 2], index=native_idx2, name="zyx")
snow_idx = s.index
- assert_series_equal(snow_idx._parent, s)
+ assert_series_equal(snow_idx._parent._parent, s)
assert_index_equal(snow_idx, native_idx2)
@@ -232,6 +274,76 @@ def test_normalize():
)
+@sql_count_checker(query_count=1, join_count=1)
+@timezones
+def test_tz_convert(tz):
+ native_index = native_pd.date_range(
+ start="2021-01-01", periods=5, freq="7h", tz="US/Eastern"
+ )
+ native_index = native_index.append(
+ native_pd.DatetimeIndex([pd.NaT], tz="US/Eastern")
+ )
+ snow_index = pd.DatetimeIndex(native_index)
+
+ # Using eval_snowpark_pandas_result() was not possible because currently
+    # Snowpark pandas DatetimeIndex only maintains a timezone-naive dtype
+ # even if the data contains a timezone.
+ assert snow_index.tz_convert(tz).equals(
+ pd.DatetimeIndex(native_index.tz_convert(tz))
+ )
+
+
+@sql_count_checker(query_count=1, join_count=1)
+@timezones
+def test_tz_localize(tz):
+ native_index = native_pd.DatetimeIndex(
+ [
+ "2014-04-04 23:56:01.000000001",
+ "2014-07-18 21:24:02.000000002",
+ "2015-11-22 22:14:03.000000003",
+ "2015-11-23 20:12:04.1234567890",
+ pd.NaT,
+ ],
+ )
+ snow_index = pd.DatetimeIndex(native_index)
+
+ # Using eval_snowpark_pandas_result() was not possible because currently
+    # Snowpark pandas DatetimeIndex only maintains a timezone-naive dtype
+ # even if the data contains a timezone.
+ assert snow_index.tz_localize(tz).equals(
+ pd.DatetimeIndex(native_index.tz_localize(tz))
+ )
+
+
+@pytest.mark.parametrize(
+ "ambiguous, nonexistent",
+ [
+ ("infer", "raise"),
+ ("NaT", "raise"),
+ (np.array([True, True, False]), "raise"),
+ ("raise", "shift_forward"),
+ ("raise", "shift_backward"),
+ ("raise", "NaT"),
+ ("raise", pd.Timedelta("1h")),
+ ("infer", "shift_forward"),
+ ],
+)
+@sql_count_checker(query_count=0)
+def test_tz_localize_negative(ambiguous, nonexistent):
+ native_index = native_pd.DatetimeIndex(
+ [
+ "2014-04-04 23:56:01.000000001",
+ "2014-07-18 21:24:02.000000002",
+ "2015-11-22 22:14:03.000000003",
+ "2015-11-23 20:12:04.1234567890",
+ pd.NaT,
+ ],
+ )
+ snow_index = pd.DatetimeIndex(native_index)
+ with pytest.raises(NotImplementedError):
+ snow_index.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent)
+
+
@pytest.mark.parametrize(
"datetime_index_value",
[
@@ -268,7 +380,12 @@ def test_floor_ceil_round(datetime_index_value, func, freq):
[
("1w", "raise", "raise"),
("1h", "infer", "raise"),
+ ("1h", "NaT", "raise"),
+ ("1h", np.array([True, True, False]), "raise"),
("1h", "raise", "shift_forward"),
+ ("1h", "raise", "shift_backward"),
+ ("1h", "raise", "NaT"),
+ ("1h", "raise", pd.Timedelta("1h")),
("1w", "infer", "shift_forward"),
],
)
diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py
index 8d0434915ac..6b33eb89889 100644
--- a/tests/integ/modin/index/test_index_methods.py
+++ b/tests/integ/modin/index/test_index_methods.py
@@ -393,13 +393,13 @@ def test_index_parent():
# DataFrame case.
df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1)
snow_idx = df.index
- assert_frame_equal(snow_idx._parent, df)
+ assert_frame_equal(snow_idx._parent._parent, df)
assert_index_equal(snow_idx, native_idx1)
# Series case.
s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx")
snow_idx = s.index
- assert_series_equal(snow_idx._parent, s)
+ assert_series_equal(snow_idx._parent._parent, s)
assert_index_equal(snow_idx, native_idx2)
diff --git a/tests/integ/modin/index/test_name.py b/tests/integ/modin/index/test_name.py
index b916110f386..f915598c5f6 100644
--- a/tests/integ/modin/index/test_name.py
+++ b/tests/integ/modin/index/test_name.py
@@ -351,3 +351,69 @@ def test_index_names_with_lazy_index():
),
inplace=True,
)
+
+
+@sql_count_checker(query_count=1)
+def test_index_names_replace_behavior():
+ """
+    Check that changing the name of a previously obtained Index object no longer updates the DataFrame's index name after the DataFrame has been modified in place.
+ """
+ data = {
+ "A": [0, 1, 2, 3, 4, 4],
+ "B": ["a", "b", "c", "d", "e", "f"],
+ }
+ idx = [1, 2, 3, 4, 5, 6]
+ native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
+ snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))
+
+ # Get a reference to the index of the DataFrames.
+ snow_index = snow_df.index
+ native_index = native_df.index
+
+ # Change the names.
+ snow_index.name = "test2"
+ native_index.name = "test2"
+
+ # Compare the names.
+ assert snow_index.name == native_index.name == "test2"
+ assert snow_df.index.name == native_df.index.name == "test2"
+
+ # Change the query compiler the DataFrame is referring to, change the names.
+ snow_df.dropna(inplace=True)
+ native_df.dropna(inplace=True)
+ snow_index.name = "test3"
+ native_index.name = "test3"
+
+ # Compare the names. Changing the index name should not change the DataFrame's index name.
+ assert snow_index.name == native_index.name == "test3"
+ assert snow_df.index.name == native_df.index.name == "test2"
+
+
+@sql_count_checker(query_count=1)
+def test_index_names_multiple_renames():
+ """
+    Check that the index name of a DataFrame can be changed any number of times.
+ """
+ data = {
+ "A": [0, 1, 2, 3, 4, 4],
+ "B": ["a", "b", "c", "d", "e", "f"],
+ }
+ idx = [1, 2, 3, 4, 5, 6]
+ native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
+ snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))
+
+ # Get a reference to the index of the DataFrames.
+ snow_index = snow_df.index
+ native_index = native_df.index
+
+ # Change and compare the names.
+ snow_index.name = "test2"
+ native_index.name = "test2"
+ assert snow_index.name == native_index.name == "test2"
+ assert snow_df.index.name == native_df.index.name == "test2"
+
+ # Change the names again and compare.
+ snow_index.name = "test3"
+ native_index.name = "test3"
+ assert snow_index.name == native_index.name == "test3"
+ assert snow_df.index.name == native_df.index.name == "test3"
diff --git a/tests/integ/modin/index/test_timedelta_index_methods.py b/tests/integ/modin/index/test_timedelta_index_methods.py
index 25bef5364f2..c4d4a0b3a66 100644
--- a/tests/integ/modin/index/test_timedelta_index_methods.py
+++ b/tests/integ/modin/index/test_timedelta_index_methods.py
@@ -128,3 +128,29 @@ def test_timedelta_total_seconds():
native_index = native_pd.TimedeltaIndex(TIMEDELTA_INDEX_DATA)
snow_index = pd.Index(native_index)
eval_snowpark_pandas_result(snow_index, native_index, lambda x: x.total_seconds())
+
+
+@pytest.mark.parametrize("skipna", [True, False])
+@pytest.mark.parametrize("data", [[1, 2, 3], [1, 2, 3, None], [None], []])
+@sql_count_checker(query_count=1)
+def test_timedelta_index_mean(skipna, data):
+ native_index = native_pd.TimedeltaIndex(data)
+ snow_index = pd.Index(native_index)
+ native_result = native_index.mean(skipna=skipna)
+ snow_result = snow_index.mean(skipna=skipna)
+ # Special check for NaN because Nan != Nan.
+ if pd.isna(native_result):
+ assert pd.isna(snow_result)
+ else:
+ assert snow_result == native_result
+
+
+@sql_count_checker(query_count=0)
+def test_timedelta_index_mean_invalid_axis():
+ native_index = native_pd.TimedeltaIndex([1, 2, 3])
+ snow_index = pd.Index(native_index)
+ with pytest.raises(IndexError, match="tuple index out of range"):
+ native_index.mean(axis=1)
+ # Snowpark pandas raises ValueError instead of IndexError.
+ with pytest.raises(ValueError, match="axis should be 0 for TimedeltaIndex.mean"):
+ snow_index.mean(axis=1).to_pandas()
diff --git a/tests/integ/modin/series/test_aggregate.py b/tests/integ/modin/series/test_aggregate.py
index fa354fda1fc..c3e40828d94 100644
--- a/tests/integ/modin/series/test_aggregate.py
+++ b/tests/integ/modin/series/test_aggregate.py
@@ -1,6 +1,8 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
+import re
+
import modin.pandas as pd
import numpy as np
import pandas as native_pd
@@ -17,6 +19,7 @@
MAP_DATA_AND_TYPE,
MIXED_NUMERIC_STR_DATA_AND_TYPE,
TIMESTAMP_DATA_AND_TYPE,
+ assert_snowpark_pandas_equals_to_pandas_without_dtypecheck,
create_test_series,
eval_snowpark_pandas_result,
)
@@ -358,3 +361,67 @@ def test_2_tuple_named_agg_errors_for_series(native_series, agg_kwargs):
expect_exception_type=SpecificationError,
assert_exception_equal=True,
)
+
+
+class TestTimedelta:
+ """Test aggregating a timedelta series."""
+
+ @pytest.mark.parametrize(
+ "func, union_count, is_scalar",
+ [
+ pytest.param(*v, id=str(i))
+ for i, v in enumerate(
+ [
+ (lambda series: series.aggregate(["min"]), 0, False),
+ (lambda series: series.aggregate({"A": "max"}), 0, False),
+                    # this works because even though we need to do concats, all the results are non-timedelta.
+ (lambda df: df.aggregate(["all", "any", "count"]), 2, False),
+ # note: the following aggregation requires a transpose.
+ (lambda df: df.aggregate(max), 0, True),
+ (lambda df: df.min(), 0, True),
+ (lambda df: df.max(), 0, True),
+ (lambda df: df.count(), 0, True),
+ (lambda df: df.sum(), 0, True),
+ (lambda df: df.mean(), 0, True),
+ (lambda df: df.median(), 0, True),
+ (lambda df: df.std(), 0, True),
+ (lambda df: df.quantile(), 0, True),
+ (lambda df: df.quantile([0.01, 0.99]), 0, False),
+ ]
+ )
+ ],
+ )
+ def test_supported(self, func, union_count, timedelta_native_df, is_scalar):
+ with SqlCounter(query_count=1, union_count=union_count):
+ eval_snowpark_pandas_result(
+ *create_test_series(timedelta_native_df["A"]),
+ func,
+ comparator=validate_scalar_result
+ if is_scalar
+ else assert_snowpark_pandas_equals_to_pandas_without_dtypecheck,
+ )
+
+ @sql_count_checker(query_count=0)
+ def test_var_invalid(self, timedelta_native_df):
+ eval_snowpark_pandas_result(
+ *create_test_series(timedelta_native_df["A"]),
+ lambda series: series.var(),
+ expect_exception=True,
+ expect_exception_type=TypeError,
+ assert_exception_equal=False,
+ expect_exception_match=re.escape(
+ "timedelta64 type does not support var operations"
+ ),
+ )
+
+ @sql_count_checker(query_count=0)
+ @pytest.mark.xfail(
+ strict=True,
+ raises=NotImplementedError,
+ reason="requires concat(), which we cannot do with Timedelta.",
+ )
+ def test_unsupported_due_to_concat(self, timedelta_native_df):
+ eval_snowpark_pandas_result(
+ *create_test_series(timedelta_native_df["A"]),
+ lambda df: df.agg(["count", "max"]),
+ )
diff --git a/tests/integ/modin/series/test_argmax_argmin.py b/tests/integ/modin/series/test_argmax_argmin.py
index 607b36a27f3..e212e3ba2dd 100644
--- a/tests/integ/modin/series/test_argmax_argmin.py
+++ b/tests/integ/modin/series/test_argmax_argmin.py
@@ -18,6 +18,11 @@
([4, None, 1, 3, 4, 1], ["A", "B", "C", "D", "E", "F"]),
([4, None, 1, 3, 4, 1], [None, "B", "C", "D", "E", "F"]),
([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]),
+ pytest.param(
+ [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)],
+ ["A", "B", "C", "D", "E"],
+ id="timedelta",
+ ),
],
)
@pytest.mark.parametrize("func", ["argmax", "argmin"])
diff --git a/tests/integ/modin/series/test_astype.py b/tests/integ/modin/series/test_astype.py
index 5bbce79b01b..030416d65c5 100644
--- a/tests/integ/modin/series/test_astype.py
+++ b/tests/integ/modin/series/test_astype.py
@@ -173,6 +173,11 @@ def test_astype_basic(from_dtype, to_dtype):
)
def test_astype_to_DatetimeTZDtype(from_dtype, to_tz):
to_dtype = f"datetime64[ns, {to_tz}]"
+ offset_map = {
+ "UTC": "UTC",
+ "Asia/Tokyo": "UTC+09:00",
+ "America/Los_Angeles": "UTC-08:00",
+ }
seed = (
[True, False, False, True]
# if isinstance(from_dtype, BooleanDtype)
@@ -189,23 +194,22 @@ def test_astype_to_DatetimeTZDtype(from_dtype, to_tz):
native_pd.Series(seed, dtype=from_dtype).astype(to_dtype)
elif isinstance(from_dtype, StringDtype) or from_dtype is str:
# Snowpark pandas uses Snowflake's auto format detection, and the behavior can differ from native pandas
- # to_pandas always convert timezone to the local timezone today, i.e., "America/Los_angeles"
with SqlCounter(query_count=1):
assert_snowpark_pandas_equal_to_pandas(
pd.Series(seed, dtype=from_dtype).astype(to_dtype),
native_pd.Series(
[
native_pd.Timestamp("1970-01-01 00:00:00", tz="UTC").tz_convert(
- "America/Los_Angeles"
+ offset_map[to_tz]
),
native_pd.Timestamp("1970-01-01 00:00:01", tz="UTC").tz_convert(
- "America/Los_Angeles"
+ offset_map[to_tz]
),
native_pd.Timestamp("1970-01-01 00:00:02", tz="UTC").tz_convert(
- "America/Los_Angeles"
+ offset_map[to_tz]
),
native_pd.Timestamp("1970-01-01 00:00:03", tz="UTC").tz_convert(
- "America/Los_Angeles"
+ offset_map[to_tz]
),
]
),
@@ -251,15 +255,15 @@ def test_astype_to_DatetimeTZDtype(from_dtype, to_tz):
):
native_pd.Series(seed, dtype=from_dtype).astype(to_dtype)
expected_to_pandas = (
- native_pd.Series(seed, dtype=from_dtype).dt.tz_localize("UTC")
- # Snowpark pandas to_pandas() will convert timestamp_tz to default local timezone
- .dt.tz_convert("America/Los_Angeles")
+ native_pd.Series(seed, dtype=from_dtype)
+ .dt.tz_localize("UTC")
+ .dt.tz_convert(offset_map[to_tz])
)
else:
expected_to_pandas = (
- native_pd.Series(seed, dtype=from_dtype).astype(to_dtype)
- # Snowpark pandas to_pandas() will convert timestamp_tz to default local timezone
- .dt.tz_convert("America/Los_Angeles")
+ native_pd.Series(seed, dtype=from_dtype)
+ .astype(to_dtype)
+ .dt.tz_convert(offset_map[to_tz])
)
assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(
s,
@@ -392,11 +396,7 @@ def test_python_datetime_astype_DatetimeTZDtype(seed):
with SqlCounter(query_count=1):
snow = s.astype(to_dtype)
assert snow.dtype == np.dtype(" native_pd.Series:
return native_pd.Series(
@@ -140,7 +183,12 @@ def test_floor_ceil_round(datetime_index_value, func, freq):
[
("1w", "raise", "raise"),
("1h", "infer", "raise"),
+ ("1h", "NaT", "raise"),
+ ("1h", np.array([True, True, False]), "raise"),
("1h", "raise", "shift_forward"),
+ ("1h", "raise", "shift_backward"),
+ ("1h", "raise", "NaT"),
+ ("1h", "raise", pd.Timedelta("1h")),
("1w", "infer", "shift_forward"),
],
)
@@ -174,6 +222,79 @@ def test_normalize():
)
+@sql_count_checker(query_count=1)
+@timezones
+def test_tz_convert(tz):
+ datetime_index = native_pd.DatetimeIndex(
+ [
+ "2014-04-04 23:56:01.000000001",
+ "2014-07-18 21:24:02.000000002",
+ "2015-11-22 22:14:03.000000003",
+ "2015-11-23 20:12:04.1234567890",
+ pd.NaT,
+ ],
+ tz="US/Eastern",
+ )
+ native_ser = native_pd.Series(datetime_index)
+ snow_ser = pd.Series(native_ser)
+ eval_snowpark_pandas_result(
+ snow_ser,
+ native_ser,
+ lambda s: s.dt.tz_convert(tz),
+ )
+
+
+@sql_count_checker(query_count=1)
+@timezones
+def test_tz_localize(tz):
+ datetime_index = native_pd.DatetimeIndex(
+ [
+ "2014-04-04 23:56:01.000000001",
+ "2014-07-18 21:24:02.000000002",
+ "2015-11-22 22:14:03.000000003",
+ "2015-11-23 20:12:04.1234567890",
+ pd.NaT,
+ ],
+ )
+ native_ser = native_pd.Series(datetime_index)
+ snow_ser = pd.Series(native_ser)
+ eval_snowpark_pandas_result(
+ snow_ser,
+ native_ser,
+ lambda s: s.dt.tz_localize(tz),
+ )
+
+
+@pytest.mark.parametrize(
+ "ambiguous, nonexistent",
+ [
+ ("infer", "raise"),
+ ("NaT", "raise"),
+ (np.array([True, True, False]), "raise"),
+ ("raise", "shift_forward"),
+ ("raise", "shift_backward"),
+ ("raise", "NaT"),
+ ("raise", pd.Timedelta("1h")),
+ ("infer", "shift_forward"),
+ ],
+)
+@sql_count_checker(query_count=0)
+def test_tz_localize_negative(ambiguous, nonexistent):
+ datetime_index = native_pd.DatetimeIndex(
+ [
+ "2014-04-04 23:56:01.000000001",
+ "2014-07-18 21:24:02.000000002",
+ "2015-11-22 22:14:03.000000003",
+ "2015-11-23 20:12:04.1234567890",
+ pd.NaT,
+ ],
+ )
+ native_ser = native_pd.Series(datetime_index)
+ snow_ser = pd.Series(native_ser)
+ with pytest.raises(NotImplementedError):
+ snow_ser.dt.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent)
+
+
@pytest.mark.parametrize("name", [None, "hello"])
def test_isocalendar(name):
with SqlCounter(query_count=1):
diff --git a/tests/integ/modin/series/test_first_last_valid_index.py b/tests/integ/modin/series/test_first_last_valid_index.py
index 1e8d052e10f..1930bdf1088 100644
--- a/tests/integ/modin/series/test_first_last_valid_index.py
+++ b/tests/integ/modin/series/test_first_last_valid_index.py
@@ -22,6 +22,10 @@
native_pd.Series([5, 6, 7, 8], index=["i", "am", "iron", "man"]),
native_pd.Series([None, None, 2], index=[None, 1, 2]),
native_pd.Series([None, None, 2], index=[None, None, None]),
+ pytest.param(
+ native_pd.Series([None, None, pd.Timedelta(2)], index=[None, 1, 2]),
+ id="timedelta",
+ ),
],
)
def test_first_and_last_valid_index_series(native_series):
diff --git a/tests/integ/modin/series/test_idxmax_idxmin.py b/tests/integ/modin/series/test_idxmax_idxmin.py
index ea536240a42..e8e66a30f61 100644
--- a/tests/integ/modin/series/test_idxmax_idxmin.py
+++ b/tests/integ/modin/series/test_idxmax_idxmin.py
@@ -17,6 +17,11 @@
([1, None, 4, 3, 4], ["A", "B", "C", "D", "E"]),
([1, None, 4, 3, 4], [None, "B", "C", "D", "E"]),
([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]),
+ pytest.param(
+ [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)],
+ ["A", "B", "C", "D", "E"],
+ id="timedelta",
+ ),
],
)
@pytest.mark.parametrize("func", ["idxmax", "idxmin"])
diff --git a/tests/integ/modin/series/test_nunique.py b/tests/integ/modin/series/test_nunique.py
index bb20e9e4a53..3856dbc516a 100644
--- a/tests/integ/modin/series/test_nunique.py
+++ b/tests/integ/modin/series/test_nunique.py
@@ -6,6 +6,7 @@
import numpy as np
import pandas as native_pd
import pytest
+from pytest import param
import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import sql_count_checker
@@ -32,6 +33,20 @@
[True, None, False, True, None],
[1.1, "a", None] * 4,
[native_pd.to_datetime("2023-12-01"), native_pd.to_datetime("1999-09-09")] * 2,
+ param(
+ [
+ native_pd.Timedelta(1),
+ native_pd.Timedelta(1),
+ native_pd.Timedelta(2),
+ None,
+ None,
+ ],
+ id="timedelta_with_nulls",
+ ),
+ param(
+ [native_pd.Timedelta(1), native_pd.Timedelta(1), native_pd.Timedelta(2)],
+ id="timedelta_without_nulls",
+ ),
],
)
@pytest.mark.parametrize("dropna", [True, False])
diff --git a/tests/integ/modin/test_classes.py b/tests/integ/modin/test_classes.py
index c92bb85c531..6e6c2eda8eb 100644
--- a/tests/integ/modin/test_classes.py
+++ b/tests/integ/modin/test_classes.py
@@ -34,14 +34,14 @@ def test_class_names_constructors():
expect_type_check(
df,
pd.DataFrame,
- "snowflake.snowpark.modin.pandas.dataframe.DataFrame",
+ "modin.pandas.dataframe.DataFrame",
)
s = pd.Series(index=[1, 2, 3], data=[3, 2, 1])
expect_type_check(
s,
pd.Series,
- "snowflake.snowpark.modin.pandas.series.Series",
+ "modin.pandas.series.Series",
)
@@ -63,7 +63,7 @@ def test_op():
expect_type_check(
df,
pd.DataFrame,
- "snowflake.snowpark.modin.pandas.dataframe.DataFrame",
+ "modin.pandas.dataframe.DataFrame",
)
@@ -77,7 +77,7 @@ def test_native_conversion():
expect_type_check(
df,
pd.DataFrame,
- "snowflake.snowpark.modin.pandas.dataframe.DataFrame",
+ "modin.pandas.dataframe.DataFrame",
)
# Snowpark pandas -> native pandas
diff --git a/tests/integ/modin/test_dtype_mapping.py b/tests/integ/modin/test_dtype_mapping.py
index 868a37ff22d..2e474c2aec4 100644
--- a/tests/integ/modin/test_dtype_mapping.py
+++ b/tests/integ/modin/test_dtype_mapping.py
@@ -281,15 +281,11 @@
"timestamp_tz timestamp_tz",
"values ('2023-01-01 00:00:01.001 +0000'), ('2023-12-31 23:59:59.999 +1000')", # timestamp_tz only supports tz offset
dtype(" from_pandas => TIMESTAMP_TZ(any_tz) => to_pandas => DatetimeTZDtype(session_tz)
- #
- # Note that python connector will convert any TIMESTAMP_TZ to DatetimeTZDtype with the current session/statement
- # timezone, e.g., 1969-12-31 19:00:00 -0500 will be converted to 1970-00-01 00:00:00 in UTC if the session/statement
- # parameter TIMEZONE = 'UTC'
- # TODO: SNOW-871210 no need session parameter change once the bug is fixed
- try:
- session.sql(f"alter session set timezone = '{timezone}'").collect()
-
- def get_series_with_tz(tz):
- return (
- native_pd.Series([1] * 3)
- .astype("int64")
- .astype(f"datetime64[ns, {tz}]")
- )
+@sql_count_checker(query_count=1)
+def test_from_to_pandas_datetime64_timezone_support():
+ def get_series_with_tz(tz):
+ return native_pd.Series([1] * 3).astype("int64").astype(f"datetime64[ns, {tz}]")
- # same timestamps representing in different time zone
- test_data_columns = {
- "utc": get_series_with_tz("UTC"),
- "pacific": get_series_with_tz("US/Pacific"),
- "tokyo": get_series_with_tz("Asia/Tokyo"),
- }
+ # the same timestamps, represented in different time zones
+ test_data_columns = {
+ "utc": get_series_with_tz("UTC"),
+ "pacific": get_series_with_tz("US/Pacific"),
+ "tokyo": get_series_with_tz("Asia/Tokyo"),
+ }
- # expected to_pandas dataframe's timezone is controlled by session/statement parameter TIMEZONE
- expected_to_pandas = native_pd.DataFrame(
- {
- series: test_data_columns[series].dt.tz_convert(timezone)
- for series in test_data_columns
- }
- )
- assert_snowpark_pandas_equal_to_pandas(
- pd.DataFrame(test_data_columns),
- expected_to_pandas,
- # configure different timezones to to_pandas and verify the timestamps are converted correctly
- statement_params={"timezone": timezone},
- )
- finally:
- # TODO: SNOW-871210 no need session parameter change once the bug is fixed
- session.sql("alter session unset timezone").collect()
+ expected_data_columns = {
+ "utc": get_series_with_tz("UTC"),
+ "pacific": get_series_with_tz("UTC-08:00"),
+ "tokyo": get_series_with_tz("UTC+09:00"),
+ }
+ # expected to_pandas result keeps each TIMESTAMP_TZ column's original timezone offset instead of converting to the session timezone
+ expected_to_pandas = native_pd.DataFrame(expected_data_columns)
+ assert_snowpark_pandas_equal_to_pandas(
+ pd.DataFrame(test_data_columns),
+ expected_to_pandas,
+ )
-@pytest.mark.parametrize("timezone", ["UTC", "US/Pacific", "US/Eastern"])
-@sql_count_checker(query_count=3)
-def test_from_to_pandas_datetime64_multi_timezone_current_behavior(session, timezone):
- try:
- # TODO: SNOW-871210 no need session parameter change once the bug is fixed
- session.sql(f"alter session set timezone = '{timezone}'").collect()
-
- # This test also verifies the current behaviors of to_pandas() for datetime with no tz, same tz, or multi tz:
- # no tz => TIMESTAMP_NTZ
- # same tz => TIMESTAMP_TZ
- # multi tz => TIMESTAMP_NTZ
- multi_tz_data = ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-07:00"]
- test_data_columns = {
- "no tz": native_pd.to_datetime(
- native_pd.Series(["2019-05-21 12:00:00", "2019-05-21 12:15:00"])
- ), # dtype = datetime64[ns]
- "same tz": native_pd.to_datetime(
- native_pd.Series(
- ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-06:00"]
- )
- ), # dtype = datetime64[ns, tz]
- "multi tz": native_pd.to_datetime(
- native_pd.Series(multi_tz_data)
- ), # dtype = object and value type is Python datetime
- }
+@sql_count_checker(query_count=1)
+def test_from_to_pandas_datetime64_multi_timezone_current_behavior():
+ # This test also verifies the current behaviors of to_pandas() for datetime with no tz, same tz, or multi tz:
+ # no tz => TIMESTAMP_NTZ
+ # same tz => TIMESTAMP_TZ
+ # multi tz => TIMESTAMP_TZ
+ multi_tz_data = ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-07:00"]
+ test_data_columns = {
+ "no tz": native_pd.to_datetime(
+ native_pd.Series(["2019-05-21 12:00:00", "2019-05-21 12:15:00"])
+ ), # dtype = datetime64[ns]
+ "same tz": native_pd.to_datetime(
+ native_pd.Series(["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-06:00"])
+ ), # dtype = datetime64[ns, tz]
+ "multi tz": native_pd.to_datetime(
+ native_pd.Series(multi_tz_data)
+ ), # dtype = object and value type is Python datetime
+ }
- expected_to_pandas = native_pd.DataFrame(
- {
- "no tz": test_data_columns["no tz"], # dtype = datetime64[ns]
- "same tz": test_data_columns["same tz"].dt.tz_convert(
- timezone
- ), # dtype = datetime64[ns, tz]
- "multi tz": native_pd.Series(
- [
- native_pd.to_datetime(t).tz_convert(timezone)
- for t in multi_tz_data
- ]
- ),
- }
- )
+ expected_to_pandas = native_pd.DataFrame(test_data_columns)
- test_df = native_pd.DataFrame(test_data_columns)
- # dtype checks for each series
- no_tz_dtype = test_df.dtypes["no tz"]
- assert is_datetime64_any_dtype(no_tz_dtype) and not isinstance(
- no_tz_dtype, DatetimeTZDtype
- )
- same_tz_dtype = test_df.dtypes["same tz"]
- assert is_datetime64_any_dtype(same_tz_dtype) and isinstance(
- same_tz_dtype, DatetimeTZDtype
- )
- multi_tz_dtype = test_df.dtypes["multi tz"]
- assert (
- not is_datetime64_any_dtype(multi_tz_dtype)
- and not isinstance(multi_tz_dtype, DatetimeTZDtype)
- and str(multi_tz_dtype) == "object"
- )
- # sample value
- assert isinstance(test_df["multi tz"][0], datetime.datetime)
- assert test_df["multi tz"][0].tzinfo is not None
- assert_snowpark_pandas_equal_to_pandas(
- pd.DataFrame(test_df),
- expected_to_pandas,
- statement_params={"timezone": timezone},
- )
- finally:
- # TODO: SNOW-871210 no need session parameter change once the bug is fixed
- session.sql("alter session unset timezone").collect()
+ test_df = native_pd.DataFrame(test_data_columns)
+ # dtype checks for each series
+ no_tz_dtype = test_df.dtypes["no tz"]
+ assert is_datetime64_any_dtype(no_tz_dtype) and not isinstance(
+ no_tz_dtype, DatetimeTZDtype
+ )
+ same_tz_dtype = test_df.dtypes["same tz"]
+ assert is_datetime64_any_dtype(same_tz_dtype) and isinstance(
+ same_tz_dtype, DatetimeTZDtype
+ )
+ multi_tz_dtype = test_df.dtypes["multi tz"]
+ assert (
+ not is_datetime64_any_dtype(multi_tz_dtype)
+ and not isinstance(multi_tz_dtype, DatetimeTZDtype)
+ and str(multi_tz_dtype) == "object"
+ )
+ # sample value
+ assert isinstance(test_df["multi tz"][0], datetime.datetime)
+ assert test_df["multi tz"][0].tzinfo is not None
+ assert_snowpark_pandas_equal_to_pandas(
+ pd.DataFrame(test_df),
+ expected_to_pandas,
+ )
@sql_count_checker(query_count=1)
diff --git a/tests/integ/modin/test_merge_asof.py b/tests/integ/modin/test_merge_asof.py
index 681d339da90..51dda7889e7 100644
--- a/tests/integ/modin/test_merge_asof.py
+++ b/tests/integ/modin/test_merge_asof.py
@@ -105,6 +105,7 @@ def left_right_timestamp_data():
pd.Timestamp("2016-05-25 13:30:00.072"),
pd.Timestamp("2016-05-25 13:30:00.075"),
],
+ "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"],
"bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01],
"ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03],
}
@@ -118,6 +119,7 @@ def left_right_timestamp_data():
pd.Timestamp("2016-05-25 13:30:00.048"),
pd.Timestamp("2016-05-25 13:30:00.048"),
],
+ "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"],
"price": [51.95, 51.95, 720.77, 720.92, 98.0],
"quantity": [75, 155, 100, 100, 100],
}
@@ -229,14 +231,39 @@ def test_merge_asof_left_right_on(
assert_snowpark_pandas_equal_to_pandas(snow_output, native_output)
+@pytest.mark.parametrize("by", ["ticker", ["ticker"]])
@sql_count_checker(query_count=1, join_count=1)
-def test_merge_asof_timestamps(left_right_timestamp_data):
+def test_merge_asof_by(left_right_timestamp_data, by):
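+ # "by" adds an exact-match condition on ticker before the nearest-key match on "time".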
left_native_df, right_native_df = left_right_timestamp_data
left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame(
right_native_df
)
- native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time")
- snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time")
+ native_output = native_pd.merge_asof(
+ left_native_df, right_native_df, on="time", by=by
+ )
+ snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by=by)
+ assert_snowpark_pandas_equal_to_pandas(snow_output, native_output)
+
+
+@pytest.mark.parametrize(
+ "left_by, right_by",
+ [
+ ("ticker", "ticker"),
+ (["ticker", "bid"], ["ticker", "price"]),
+ ],
+)
+@sql_count_checker(query_count=1, join_count=1)
+def test_merge_asof_left_right_by(left_right_timestamp_data, left_by, right_by):
+ left_native_df, right_native_df = left_right_timestamp_data
+ left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame(
+ right_native_df
+ )
+ native_output = native_pd.merge_asof(
+ left_native_df, right_native_df, on="time", left_by=left_by, right_by=right_by
+ )
+ snow_output = pd.merge_asof(
+ left_snow_df, right_snow_df, on="time", left_by=left_by, right_by=right_by
+ )
assert_snowpark_pandas_equal_to_pandas(snow_output, native_output)
@@ -248,8 +275,10 @@ def test_merge_asof_date(left_right_timestamp_data):
left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame(
right_native_df
)
- native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time")
- snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time")
+ native_output = native_pd.merge_asof(
+ left_native_df, right_native_df, on="time", by="ticker"
+ )
+ snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by="ticker")
assert_snowpark_pandas_equal_to_pandas(snow_output, native_output)
@@ -360,9 +389,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data):
with pytest.raises(
NotImplementedError,
match=(
- "Snowpark pandas merge_asof method does not currently support parameters "
- "'by', 'left_by', 'right_by', 'left_index', 'right_index', "
- "'suffixes', or 'tolerance'"
+ "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'"
),
):
pd.merge_asof(
@@ -372,19 +399,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data):
NotImplementedError,
match=(
"Snowpark pandas merge_asof method does not currently support parameters "
- "'by', 'left_by', 'right_by', 'left_index', 'right_index', "
- "'suffixes', or 'tolerance'"
- ),
- ):
- pd.merge_asof(
- left_snow_df, right_snow_df, on="time", left_by="price", right_by="quantity"
- )
- with pytest.raises(
- NotImplementedError,
- match=(
- "Snowpark pandas merge_asof method does not currently support parameters "
- "'by', 'left_by', 'right_by', 'left_index', 'right_index', "
- "'suffixes', or 'tolerance'"
+ + "'left_index', 'right_index', 'suffixes', or 'tolerance'"
),
):
pd.merge_asof(left_snow_df, right_snow_df, left_index=True, right_index=True)
@@ -392,8 +407,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data):
NotImplementedError,
match=(
"Snowpark pandas merge_asof method does not currently support parameters "
- "'by', 'left_by', 'right_by', 'left_index', 'right_index', "
- "'suffixes', or 'tolerance'"
+ + "'left_index', 'right_index', 'suffixes', or 'tolerance'"
),
):
pd.merge_asof(
@@ -406,8 +420,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data):
NotImplementedError,
match=(
"Snowpark pandas merge_asof method does not currently support parameters "
- "'by', 'left_by', 'right_by', 'left_index', 'right_index', "
- "'suffixes', or 'tolerance'"
+ + "'left_index', 'right_index', 'suffixes', or 'tolerance'"
),
):
pd.merge_asof(
diff --git a/tests/integ/modin/test_telemetry.py b/tests/integ/modin/test_telemetry.py
index ce9e1caf328..a36298af251 100644
--- a/tests/integ/modin/test_telemetry.py
+++ b/tests/integ/modin/test_telemetry.py
@@ -110,7 +110,7 @@ def test_snowpark_pandas_telemetry_method_decorator(test_table_name):
df1_expected_api_calls = [
{"name": "TestClass.test_func"},
- {"name": "DataFrame.DataFrame.dropna", "argument": ["inplace"]},
+ {"name": "DataFrame.dropna", "argument": ["inplace"]},
]
assert df1._query_compiler.snowpark_pandas_api_calls == df1_expected_api_calls
@@ -121,7 +121,7 @@ def test_snowpark_pandas_telemetry_method_decorator(test_table_name):
assert df1._query_compiler.snowpark_pandas_api_calls == df1_expected_api_calls
df2_expected_api_calls = df1_expected_api_calls + [
{
- "name": "DataFrame.DataFrame.dropna",
+ "name": "DataFrame.dropna",
},
]
assert df2._query_compiler.snowpark_pandas_api_calls == df2_expected_api_calls
@@ -336,10 +336,7 @@ def test_telemetry_with_update_inplace():
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
df.insert(1, "newcol", [99, 99, 90])
assert len(df._query_compiler.snowpark_pandas_api_calls) == 1
- assert (
- df._query_compiler.snowpark_pandas_api_calls[0]["name"]
- == "DataFrame.DataFrame.insert"
- )
+ assert df._query_compiler.snowpark_pandas_api_calls[0]["name"] == "DataFrame.insert"
@sql_count_checker(query_count=1)
@@ -403,8 +400,8 @@ def test_telemetry_getitem_setitem():
df["a"] = 0
df["b"] = 0
assert df._query_compiler.snowpark_pandas_api_calls == [
- {"name": "DataFrame.DataFrame.__setitem__"},
- {"name": "DataFrame.DataFrame.__setitem__"},
+ {"name": "DataFrame.__setitem__"},
+ {"name": "DataFrame.__setitem__"},
]
# Clear connector telemetry client buffer to avoid flush triggered by the next API call, ensuring log extraction.
s._query_compiler._modin_frame.ordered_dataframe.session._conn._telemetry_client.telemetry.send_batch()
@@ -422,13 +419,17 @@ def test_telemetry_getitem_setitem():
@pytest.mark.parametrize(
- "name, method, expected_query_count",
+ "name, expected_func_name, method, expected_query_count",
[
- ["__repr__", lambda df: df.__repr__(), 1],
- ["__iter__", lambda df: df.__iter__(), 0],
+ # __repr__ is an extension method, so the class name is shown only once.
+ ["__repr__", "DataFrame.__repr__", lambda df: df.__repr__(), 1],
+ # __iter__ is defined directly on the DataFrame class, so the class name is shown twice.
+ ["__iter__", "DataFrame.DataFrame.__iter__", lambda df: df.__iter__(), 0],
],
)
-def test_telemetry_private_method(name, method, expected_query_count):
+def test_telemetry_private_method(
+ name, expected_func_name, method, expected_query_count
+):
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
# Clear connector telemetry client buffer to avoid flush triggered by the next API call, ensuring log extraction.
df._query_compiler._modin_frame.ordered_dataframe.session._conn._telemetry_client.telemetry.send_batch()
@@ -439,10 +440,10 @@ def test_telemetry_private_method(name, method, expected_query_count):
# the telemetry log from the connector to validate
data = _extract_snowpark_pandas_telemetry_log_data(
- expected_func_name=f"DataFrame.DataFrame.{name}",
+ expected_func_name=expected_func_name,
session=df._query_compiler._modin_frame.ordered_dataframe.session,
)
- assert data["api_calls"] == [{"name": f"DataFrame.DataFrame.{name}"}]
+ assert data["api_calls"] == [{"name": expected_func_name}]
@sql_count_checker(query_count=0)
diff --git a/tests/integ/modin/tools/test_to_datetime.py b/tests/integ/modin/tools/test_to_datetime.py
index 1ea3445d15a..df11e6afb80 100644
--- a/tests/integ/modin/tools/test_to_datetime.py
+++ b/tests/integ/modin/tools/test_to_datetime.py
@@ -565,7 +565,7 @@ def test_to_datetime_mixed_datetime_and_string(self):
assert_index_equal(res, expected)
# Set utc=True to make sure timezone aware in to_datetime
res = to_datetime(pd.Index(["2020-01-01 17:00:00 -0100", d2]), utc=True)
- expected = pd.DatetimeIndex([d1, d2])
+ expected = pd.DatetimeIndex([d1, d2], tz="UTC")
assert_index_equal(res, expected)
@pytest.mark.parametrize(
diff --git a/tests/integ/modin/types/test_timedelta.py b/tests/integ/modin/types/test_timedelta.py
index 4c72df42bba..d28362374ce 100644
--- a/tests/integ/modin/types/test_timedelta.py
+++ b/tests/integ/modin/types/test_timedelta.py
@@ -2,10 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import datetime
+import warnings
import modin.pandas as pd
import pandas as native_pd
import pytest
+from pandas.errors import SettingWithCopyWarning
from tests.integ.modin.sql_counter import sql_count_checker
from tests.integ.modin.utils import (
@@ -107,3 +109,10 @@ def test_timedelta_not_supported():
match="SnowflakeQueryCompiler::groupby_groups is not yet implemented for Timedelta Type",
):
df.groupby("a").groups()
+
+
+@sql_count_checker(query_count=1)
+def test_aggregation_does_not_print_internal_warning_SNOW_1664064():
+ with warnings.catch_warnings():
+ warnings.simplefilter(category=SettingWithCopyWarning, action="error")
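+ # with action="error", any SettingWithCopyWarning emitted during the aggregation fails the test.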
+ pd.Series(pd.Timedelta(1)).max()
diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py
index e42a504a976..bdd780ea69e 100644
--- a/tests/integ/test_large_query_breakdown.py
+++ b/tests/integ/test_large_query_breakdown.py
@@ -9,9 +9,13 @@
import pytest
from snowflake.snowpark._internal.analyzer import analyzer
-from snowflake.snowpark._internal.compiler import large_query_breakdown
from snowflake.snowpark.functions import col, lit, sum_distinct, when_matched
from snowflake.snowpark.row import Row
+from snowflake.snowpark.session import (
+ DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND,
+ DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND,
+ Session,
+)
from tests.utils import Utils
pytestmark = [
@@ -22,9 +26,6 @@
)
]
-DEFAULT_LOWER_BOUND = large_query_breakdown.COMPLEXITY_SCORE_LOWER_BOUND
-DEFAULT_UPPER_BOUND = large_query_breakdown.COMPLEXITY_SCORE_UPPER_BOUND
-
@pytest.fixture(autouse=True)
def large_query_df(session):
@@ -50,20 +51,24 @@ def setup(session):
is_query_compilation_stage_enabled = session._query_compilation_stage_enabled
session._query_compilation_stage_enabled = True
session._large_query_breakdown_enabled = True
+ set_bounds(session, 300, 600)
yield
session._query_compilation_stage_enabled = is_query_compilation_stage_enabled
session._cte_optimization_enabled = cte_optimization_enabled
session._large_query_breakdown_enabled = large_query_breakdown_enabled
- reset_bounds()
+ reset_bounds(session)
-def set_bounds(lower_bound: int, upper_bound: int):
- large_query_breakdown.COMPLEXITY_SCORE_LOWER_BOUND = lower_bound
- large_query_breakdown.COMPLEXITY_SCORE_UPPER_BOUND = upper_bound
+def set_bounds(session: Session, lower_bound: int, upper_bound: int):
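+ # the complexity bounds are now configured per session as a (lower, upper) tuple instead of module-level constants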
+ session._large_query_breakdown_complexity_bounds = (lower_bound, upper_bound)
-def reset_bounds():
- set_bounds(DEFAULT_LOWER_BOUND, DEFAULT_UPPER_BOUND)
+def reset_bounds(session: Session):
+ set_bounds(
+ session,
+ DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND,
+ DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND,
+ )
def check_result_with_and_without_breakdown(session, df):
@@ -82,8 +87,6 @@ def check_result_with_and_without_breakdown(session, df):
def test_no_valid_nodes_found(session, large_query_df, caplog):
"""Test large query breakdown works with default bounds"""
- set_bounds(300, 600)
-
base_df = session.sql("select 1 as A, 2 as B")
df1 = base_df.with_column("A", col("A") + lit(1))
df2 = base_df.with_column("B", col("B") + lit(1))
@@ -104,7 +107,6 @@ def test_no_valid_nodes_found(session, large_query_df, caplog):
def test_large_query_breakdown_with_cte_optimization(session):
"""Test large query breakdown works with cte optimized plan"""
- set_bounds(300, 600)
session._cte_optimization_enabled = True
df0 = session.sql("select 2 as b, 32 as c")
df1 = session.sql("select 1 as a, 2 as b").filter(col("a") == 1)
@@ -131,7 +133,6 @@ def test_large_query_breakdown_with_cte_optimization(session):
def test_save_as_table(session, large_query_df):
- set_bounds(300, 600)
table_name = Utils.random_table_name()
with session.query_history() as history:
large_query_df.write.save_as_table(table_name, mode="overwrite")
@@ -146,7 +147,6 @@ def test_save_as_table(session, large_query_df):
def test_update_delete_merge(session, large_query_df):
- set_bounds(300, 600)
session._large_query_breakdown_enabled = True
table_name = Utils.random_table_name()
df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"])
@@ -186,7 +186,6 @@ def test_update_delete_merge(session, large_query_df):
def test_copy_into_location(session, large_query_df):
- set_bounds(300, 600)
remote_file_path = f"{session.get_session_stage()}/df.parquet"
with session.query_history() as history:
large_query_df.write.copy_into_location(
@@ -204,7 +203,6 @@ def test_copy_into_location(session, large_query_df):
def test_pivot_unpivot(session):
- set_bounds(300, 600)
session.sql(
"""create or replace temp table monthly_sales(A int, B int, month text)
as select * from values
@@ -243,7 +241,6 @@ def test_pivot_unpivot(session):
def test_sort(session):
- set_bounds(300, 600)
base_df = session.sql("select 1 as A, 2 as B")
df1 = base_df.with_column("A", col("A") + lit(1))
df2 = base_df.with_column("B", col("B") + lit(1))
@@ -276,7 +273,6 @@ def test_sort(session):
def test_multiple_query_plan(session, large_query_df):
- set_bounds(300, 600)
original_threshold = analyzer.ARRAY_BIND_THRESHOLD
try:
analyzer.ARRAY_BIND_THRESHOLD = 2
@@ -314,7 +310,6 @@ def test_multiple_query_plan(session, large_query_df):
def test_optimization_skipped_with_transaction(session, large_query_df, caplog):
"""Test large query breakdown is skipped when transaction is enabled"""
- set_bounds(300, 600)
session.sql("begin").collect()
assert Utils.is_active_transaction(session)
with caplog.at_level(logging.DEBUG):
@@ -330,7 +325,6 @@ def test_optimization_skipped_with_transaction(session, large_query_df, caplog):
def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog):
"""Test large query breakdown is skipped plan is a view or dynamic table"""
- set_bounds(300, 600)
source_table = Utils.random_table_name()
table_name = Utils.random_table_name()
view_name = Utils.random_view_name()
@@ -360,7 +354,6 @@ def test_optimization_skipped_with_views_and_dynamic_tables(session, caplog):
def test_async_job_with_large_query_breakdown(session, large_query_df):
"""Test large query breakdown gives same result for async and non-async jobs"""
- set_bounds(300, 600)
job = large_query_df.collect(block=False)
result = job.result()
assert result == large_query_df.collect()
@@ -376,8 +369,6 @@ def test_async_job_with_large_query_breakdown(session, large_query_df):
def test_add_parent_plan_uuid_to_statement_params(session, large_query_df):
- set_bounds(300, 600)
-
with patch.object(
session._conn, "run_query", wraps=session._conn.run_query
) as patched_run_query:
@@ -400,7 +391,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df):
"""Test complexity bounds affect number of partitions.
Also test that when partitions are added, drop table queries are added.
"""
- set_bounds(300, 600)
+ set_bounds(session, 300, 600)
assert len(large_query_df.queries["queries"]) == 2
assert len(large_query_df.queries["post_actions"]) == 1
assert large_query_df.queries["queries"][0].startswith(
@@ -410,7 +401,7 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df):
"DROP TABLE If EXISTS"
)
- set_bounds(300, 412)
+ set_bounds(session, 300, 412)
assert len(large_query_df.queries["queries"]) == 3
assert len(large_query_df.queries["post_actions"]) == 2
assert large_query_df.queries["queries"][0].startswith(
@@ -426,11 +417,11 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df):
"DROP TABLE If EXISTS"
)
- set_bounds(0, 300)
+ set_bounds(session, 0, 300)
assert len(large_query_df.queries["queries"]) == 1
assert len(large_query_df.queries["post_actions"]) == 0
- reset_bounds()
+ reset_bounds(session)
assert len(large_query_df.queries["queries"]) == 1
assert len(large_query_df.queries["post_actions"]) == 0
diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py
index 0e8bb0d902d..81b852c46c1 100644
--- a/tests/integ/test_query_plan_analysis.py
+++ b/tests/integ/test_query_plan_analysis.py
@@ -98,6 +98,24 @@ def test_range_statement(session: Session):
)
+def test_literal_complexity_for_snowflake_values(session: Session):
+ from snowflake.snowpark._internal.analyzer import analyzer
+
+ df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
+ assert_df_subtree_query_complexity(
+ df1, {PlanNodeCategory.COLUMN: 4, PlanNodeCategory.LITERAL: 4}
+ )
+
+ try:
+ original_threshold = analyzer.ARRAY_BIND_THRESHOLD
+ analyzer.ARRAY_BIND_THRESHOLD = 2
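+ # with a threshold of 2, create_dataframe writes the data to a temp table instead of inlining literals, so no LITERAL nodes are expected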
+ df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
+ # SELECT "A", "B" from (SELECT * FROM TEMP_TABLE)
+ assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN: 3})
+ finally:
+ analyzer.ARRAY_BIND_THRESHOLD = original_threshold
+
+
def test_generator_table_function(session: Session):
df1 = session.generator(
seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150
diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py
index df0afc1099b..21e77883338 100644
--- a/tests/integ/test_session.py
+++ b/tests/integ/test_session.py
@@ -5,6 +5,7 @@
import os
from functools import partial
+from unittest.mock import patch
import pytest
@@ -719,6 +720,31 @@ def test_eliminate_numeric_sql_value_cast_optimization_enabled_on_session(
new_session.eliminate_numeric_sql_value_cast_enabled = None
+def test_large_query_breakdown_complexity_bounds(session):
+ original_bounds = session.large_query_breakdown_complexity_bounds
+ try:
+ with pytest.raises(ValueError, match="Expecting a tuple of two integers"):
+ session.large_query_breakdown_complexity_bounds = (1, 2, 3)
+
+ with pytest.raises(
+ ValueError, match="Expecting a tuple of lower and upper bound"
+ ):
+ session.large_query_breakdown_complexity_bounds = (3, 2)
+
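+ # valid bounds should update the session and send a telemetry event with the session id and the new bounds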
+ with patch.object(
+ session._conn._telemetry_client,
+ "send_large_query_breakdown_update_complexity_bounds",
+ ) as patch_send:
+ session.large_query_breakdown_complexity_bounds = (1, 2)
+ assert session.large_query_breakdown_complexity_bounds == (1, 2)
+ assert patch_send.call_count == 1
+ assert patch_send.call_args[0][0] == session.session_id
+ assert patch_send.call_args[0][1] == 1
+ assert patch_send.call_args[0][2] == 2
+ finally:
+ session.large_query_breakdown_complexity_bounds = original_bounds
+
+
@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_create_session_from_default_config_file(monkeypatch, db_parameters):
import tomlkit
diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py
index bcfa2cfa512..39749de76f6 100644
--- a/tests/integ/test_telemetry.py
+++ b/tests/integ/test_telemetry.py
@@ -5,6 +5,7 @@
import decimal
import sys
+import uuid
from functools import partial
from typing import Any, Dict, Tuple
@@ -599,6 +600,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
{
"name": "Session.range",
"sql_simplifier_enabled": session.sql_simplifier_enabled,
+ "plan_uuid": df._plan.uuid,
"query_plan_height": query_plan_height,
"query_plan_num_duplicate_nodes": 0,
"query_plan_complexity": {
@@ -621,6 +623,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
{
"name": "Session.range",
"sql_simplifier_enabled": session.sql_simplifier_enabled,
+ "plan_uuid": df._plan.uuid,
"query_plan_height": query_plan_height,
"query_plan_num_duplicate_nodes": 0,
"query_plan_complexity": {
@@ -643,6 +646,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
{
"name": "Session.range",
"sql_simplifier_enabled": session.sql_simplifier_enabled,
+ "plan_uuid": df._plan.uuid,
"query_plan_height": query_plan_height,
"query_plan_num_duplicate_nodes": 0,
"query_plan_complexity": {
@@ -665,6 +669,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
{
"name": "Session.range",
"sql_simplifier_enabled": session.sql_simplifier_enabled,
+ "plan_uuid": df._plan.uuid,
"query_plan_height": query_plan_height,
"query_plan_num_duplicate_nodes": 0,
"query_plan_complexity": {
@@ -687,6 +692,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled):
{
"name": "Session.range",
"sql_simplifier_enabled": session.sql_simplifier_enabled,
+ "plan_uuid": df._plan.uuid,
"query_plan_height": query_plan_height,
"query_plan_num_duplicate_nodes": 0,
"query_plan_complexity": {
@@ -829,10 +835,15 @@ def test_dataframe_stat_functions_api_calls(session):
column = 6 if session.sql_simplifier_enabled else 9
crosstab = df.stat.crosstab("empid", "month")
+ # The uuid here is generated by an intermediate dataframe in the crosstab implementation,
+ # so we can't predict it. We check that the uuid for crosstab is the same as that for df.
+ uuid = df._plan.api_calls[0]["plan_uuid"]
assert crosstab._plan.api_calls == [
{
"name": "Session.create_dataframe[values]",
"sql_simplifier_enabled": session.sql_simplifier_enabled,
+ "plan_uuid": uuid,
"query_plan_height": 4,
"query_plan_num_duplicate_nodes": 0,
"query_plan_complexity": {"group_by": 1, "column": column, "literal": 48},
@@ -851,6 +862,7 @@ def test_dataframe_stat_functions_api_calls(session):
{
"name": "Session.create_dataframe[values]",
"sql_simplifier_enabled": session.sql_simplifier_enabled,
+ "plan_uuid": uuid,
"query_plan_height": 4,
"query_plan_num_duplicate_nodes": 0,
"query_plan_complexity": {"group_by": 1, "column": column, "literal": 48},
@@ -1166,3 +1178,96 @@ def send_large_query_optimization_skipped_telemetry():
)
assert data == expected_data
assert type_ == "snowpark_large_query_breakdown_optimization_skipped"
+
+
+def test_post_compilation_stage_telemetry(session):
+ client = session._conn._telemetry_client
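+ # a random plan uuid is enough here because the telemetry payload is sent directly rather than built from a compiled plan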
+ uuid_str = str(uuid.uuid4())
+
+ def send_telemetry():
+ summary_value = {
+ "cte_optimization_enabled": True,
+ "large_query_breakdown_enabled": True,
+ "complexity_score_bounds": (300, 600),
+ "time_taken_for_compilation": 0.136,
+ "time_taken_for_deep_copy_plan": 0.074,
+ "time_taken_for_cte_optimization": 0.01,
+ "time_taken_for_large_query_breakdown": 0.062,
+ "complexity_score_before_compilation": 1148,
+ "complexity_score_after_cte_optimization": [1148],
+ "complexity_score_after_large_query_breakdown": [514, 636],
+ }
+ client.send_query_compilation_summary_telemetry(
+ session_id=session.session_id,
+ plan_uuid=uuid_str,
+ compilation_stage_summary=summary_value,
+ )
+
+ telemetry_tracker = TelemetryDataTracker(session)
+
+ expected_data = {
+ "session_id": session.session_id,
+ "plan_uuid": uuid_str,
+ "cte_optimization_enabled": True,
+ "large_query_breakdown_enabled": True,
+ "complexity_score_bounds": (300, 600),
+ "time_taken_for_compilation": 0.136,
+ "time_taken_for_deep_copy_plan": 0.074,
+ "time_taken_for_cte_optimization": 0.01,
+ "time_taken_for_large_query_breakdown": 0.062,
+ "complexity_score_before_compilation": 1148,
+ "complexity_score_after_cte_optimization": [1148],
+ "complexity_score_after_large_query_breakdown": [514, 636],
+ }
+
+ data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry)
+ assert data == expected_data
+ assert type_ == "snowpark_compilation_stage_statistics"
+
+
+def test_temp_table_cleanup(session):
+ client = session._conn._telemetry_client
+
+ def send_telemetry():
+ client.send_temp_table_cleanup_telemetry(
+ session.session_id,
+ temp_table_cleaner_enabled=True,
+ num_temp_tables_cleaned=2,
+ num_temp_tables_created=5,
+ )
+
+ telemetry_tracker = TelemetryDataTracker(session)
+
+ expected_data = {
+ "session_id": session.session_id,
+ "temp_table_cleaner_enabled": True,
+ "num_temp_tables_cleaned": 2,
+ "num_temp_tables_created": 5,
+ }
+
+ data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry)
+ assert data == expected_data
+ assert type_ == "snowpark_temp_table_cleanup"
+
+
+def test_temp_table_cleanup_exception(session):
+ client = session._conn._telemetry_client
+
+ def send_telemetry():
+ client.send_temp_table_cleanup_abnormal_exception_telemetry(
+ session.session_id,
+ table_name="table_name_placeholder",
+ exception_message="exception_message_placeholder",
+ )
+
+ telemetry_tracker = TelemetryDataTracker(session)
+
+ expected_data = {
+ "session_id": session.session_id,
+ "temp_table_cleanup_abnormal_exception_table_name": "table_name_placeholder",
+ "temp_table_cleanup_abnormal_exception_message": "exception_message_placeholder",
+ }
+
+ data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry)
+ assert data == expected_data
+ assert type_ == "snowpark_temp_table_cleanup_abnormal_exception"
diff --git a/tests/integ/test_temp_table_cleanup.py b/tests/integ/test_temp_table_cleanup.py
index 4ac87661484..cdd97d49937 100644
--- a/tests/integ/test_temp_table_cleanup.py
+++ b/tests/integ/test_temp_table_cleanup.py
@@ -12,6 +12,7 @@
from snowflake.snowpark._internal.utils import (
TempObjectType,
random_name_for_temp_object,
+ warning_dict,
)
from snowflake.snowpark.functions import col
from tests.utils import IS_IN_STORED_PROC
@@ -25,40 +26,61 @@
WAIT_TIME = 1
+@pytest.fixture(autouse=True)
+def setup(session):
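+ # enable temp table auto-cleanup for every test in this module and restore the original setting afterwards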
+ auto_clean_up_temp_table_enabled = session.auto_clean_up_temp_table_enabled
+ session.auto_clean_up_temp_table_enabled = True
+ yield
+ session.auto_clean_up_temp_table_enabled = auto_clean_up_temp_table_enabled
+
+
def test_basic(session):
+ session._temp_table_auto_cleaner.ref_count_map.clear()
df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result()
table_name = df1.table_name
table_ids = table_name.split(".")
df1.collect()
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
df2 = df1.select("*").filter(col("a") == 1)
df2.collect()
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
df3 = df1.union_all(df2)
df3.collect()
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
- session._temp_table_auto_cleaner.start()
del df1
gc.collect()
time.sleep(WAIT_TIME)
assert session._table_exists(table_ids)
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
del df2
gc.collect()
time.sleep(WAIT_TIME)
assert session._table_exists(table_ids)
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
del df3
gc.collect()
time.sleep(WAIT_TIME)
assert not session._table_exists(table_ids)
- assert table_name not in session._temp_table_auto_cleaner.ref_count_map
+ assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1
def test_function(session):
+ session._temp_table_auto_cleaner.ref_count_map.clear()
table_name = None
def f(session: Session) -> None:
@@ -68,13 +90,16 @@ def f(session: Session) -> None:
nonlocal table_name
table_name = df.table_name
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
- session._temp_table_auto_cleaner.start()
f(session)
gc.collect()
time.sleep(WAIT_TIME)
assert not session._table_exists(table_name.split("."))
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1
@pytest.mark.parametrize(
@@ -86,33 +111,42 @@ def f(session: Session) -> None:
],
)
def test_copy(session, copy_function):
+ session._temp_table_auto_cleaner.ref_count_map.clear()
df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result()
table_name = df1.table_name
table_ids = table_name.split(".")
df1.collect()
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
df2 = copy_function(df1).select("*").filter(col("a") == 1)
df2.collect()
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 2
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
- session._temp_table_auto_cleaner.start()
del df1
gc.collect()
time.sleep(WAIT_TIME)
assert session._table_exists(table_ids)
assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
- session._temp_table_auto_cleaner.start()
del df2
gc.collect()
time.sleep(WAIT_TIME)
assert not session._table_exists(table_ids)
- assert table_name not in session._temp_table_auto_cleaner.ref_count_map
+ assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1
@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_reference_count_map_multiple_sessions(db_parameters, session):
+ session._temp_table_auto_cleaner.ref_count_map.clear()
new_session = Session.builder.configs(db_parameters).create()
+ new_session.auto_clean_up_temp_table_enabled = True
try:
df1 = session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
@@ -120,43 +154,59 @@ def test_reference_count_map_multiple_sessions(db_parameters, session):
table_name1 = df1.table_name
table_ids1 = table_name1.split(".")
assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 1
- assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
+ assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map
+ assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 0
+ assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
df2 = new_session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
).cache_result()
table_name2 = df2.table_name
table_ids2 = table_name2.split(".")
- assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0
+ assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 1
+ assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
- session._temp_table_auto_cleaner.start()
del df1
gc.collect()
time.sleep(WAIT_TIME)
assert not session._table_exists(table_ids1)
assert new_session._table_exists(table_ids2)
assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0
- assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1
+ assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map
+ assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
- new_session._temp_table_auto_cleaner.start()
del df2
gc.collect()
time.sleep(WAIT_TIME)
assert not new_session._table_exists(table_ids2)
- assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0
+ assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1
assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0
+ assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1
finally:
new_session.close()
def test_save_as_table_no_drop(session):
- session._temp_table_auto_cleaner.start()
+ session._temp_table_auto_cleaner.ref_count_map.clear()
def f(session: Session, temp_table_name: str) -> None:
session.create_dataframe(
[[1, 2], [3, 4]], schema=["a", "b"]
).write.save_as_table(temp_table_name, table_type="temp")
- assert session._temp_table_auto_cleaner.ref_count_map[temp_table_name] == 0
+ assert temp_table_name not in session._temp_table_auto_cleaner.ref_count_map
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 0
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0
temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
f(session, temp_table_name)
@@ -165,34 +215,25 @@ def f(session: Session, temp_table_name: str) -> None:
assert session._table_exists([temp_table_name])
-def test_start_stop(session):
- session._temp_table_auto_cleaner.stop()
-
- df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result()
- table_name = df1.table_name
+def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog):
+ warning_dict.clear()
+ with caplog.at_level(logging.WARNING):
+ session.auto_clean_up_temp_table_enabled = False
+ assert session.auto_clean_up_temp_table_enabled is False
+ assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text
+ df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result()
+ table_name = df.table_name
table_ids = table_name.split(".")
- assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1
- del df1
+ del df
gc.collect()
- assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0
- assert not session._temp_table_auto_cleaner.queue.empty()
- assert session._table_exists(table_ids)
-
- session._temp_table_auto_cleaner.start()
time.sleep(WAIT_TIME)
- assert session._temp_table_auto_cleaner.queue.empty()
- assert not session._table_exists(table_ids)
-
-
-def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog):
- with caplog.at_level(logging.WARNING):
- session.auto_clean_up_temp_table_enabled = True
+ assert session._table_exists(table_ids)
+ assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0
+ assert session._temp_table_auto_cleaner.num_temp_tables_created == 1
+ assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1
+ session.auto_clean_up_temp_table_enabled = True
assert session.auto_clean_up_temp_table_enabled is True
- assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text
- assert session._temp_table_auto_cleaner.is_alive()
- session.auto_clean_up_temp_table_enabled = False
- assert session.auto_clean_up_temp_table_enabled is False
- assert not session._temp_table_auto_cleaner.is_alive()
+
with pytest.raises(
ValueError,
match="value for auto_clean_up_temp_table_enabled must be True or False!",
diff --git a/tests/mock/test_multithreading.py b/tests/mock/test_multithreading.py
new file mode 100644
index 00000000000..5e0078212d6
--- /dev/null
+++ b/tests/mock/test_multithreading.py
@@ -0,0 +1,335 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+import io
+import json
+import os
+import tempfile
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from threading import Thread
+
+import pytest
+
+from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
+ LogicalPlan,
+ SaveMode,
+)
+from snowflake.snowpark._internal.utils import normalize_local_file
+from snowflake.snowpark.functions import lit, when_matched
+from snowflake.snowpark.mock._connection import MockServerConnection
+from snowflake.snowpark.mock._functions import MockedFunctionRegistry
+from snowflake.snowpark.mock._plan import MockExecutionPlan
+from snowflake.snowpark.mock._snowflake_data_type import TableEmulator
+from snowflake.snowpark.mock._stage_registry import StageEntityRegistry
+from snowflake.snowpark.mock._telemetry import LocalTestOOBTelemetryService
+from snowflake.snowpark.row import Row
+from snowflake.snowpark.session import Session
+from tests.utils import Utils
+
+
+def test_table_update_merge_delete(session):
+ table_name = Utils.random_table_name()
+ num_threads = 10
+ data = [[v, 11 * v] for v in range(10)]
+ df = session.create_dataframe(data, schema=["a", "b"])
+ df.write.save_as_table(table_name, table_type="temp")
+
+ source_df = df
+ t = session.table(table_name)
+
+ def update_table(thread_id: int):
+ t.update({"b": 0}, t.a == lit(thread_id))
+
+ def merge_table(thread_id: int):
+ t.merge(
+ source_df, t.a == source_df.a, [when_matched().update({"b": source_df.b})]
+ )
+
+ def delete_table(thread_id: int):
+ t.delete(t.a == lit(thread_id))
+
+ # all threads will update column b to 0 where a = thread_id
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ # update
+ futures = [executor.submit(update_table, i) for i in range(num_threads)]
+ for future in as_completed(futures):
+ future.result()
+
+ # all threads will set column b to 0
+ Utils.check_answer(t.select(t.b), [Row(B=0) for _ in range(10)])
+
+ # merge
+ futures = [executor.submit(merge_table, i) for i in range(num_threads)]
+ for future in as_completed(futures):
+ future.result()
+
+ # all threads will set column b to 11 * a
+ Utils.check_answer(t.select(t.b), [Row(B=11 * i) for i in range(10)])
+
+ # delete
+ futures = [executor.submit(delete_table, i) for i in range(num_threads)]
+ for future in as_completed(futures):
+ future.result()
+
+ # all threads will delete their row
+ assert t.count() == 0
+
+
+def test_udf_register_and_invoke(session):
+ df = session.create_dataframe([[1], [2]], schema=["num"])
+ num_threads = 10
+
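+ # each iteration starts one thread that re-registers the UDF with replace=True and another that invokes it concurrently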
+ def register_udf(x: int):
+ def echo(x: int) -> int:
+ return x
+
+ return session.udf.register(echo, name="echo", replace=True)
+
+ def invoke_udf():
+ result = df.select(session.udf.call_udf("echo", df.num)).collect()
+ assert result[0][0] == 1
+ assert result[1][0] == 2
+
+ threads = []
+ for i in range(num_threads):
+ thread_register = Thread(target=register_udf, args=(i,))
+ threads.append(thread_register)
+ thread_register.start()
+
+ thread_invoke = Thread(target=invoke_udf)
+ threads.append(thread_invoke)
+ thread_invoke.start()
+
+ for thread in threads:
+ thread.join()
+
+
+def test_sp_register_and_invoke(session):
+ num_threads = 10
+
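+ # several threads re-register the same stored procedure with replace=True while another thread calls it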
+ def increment_by_one_fn(session_: Session, x: int) -> int:
+ return x + 1
+
+ def register_sproc(thread_id: int):
+ session.sproc.register(
+ increment_by_one_fn, name="increment_by_one", replace=True
+ )
+
+ def invoke_sproc():
+ result = session.call("increment_by_one", 1)
+ assert result == 2
+
+ threads = []
+ for i in range(num_threads):
+ thread_register = Thread(target=register_sproc, args=(i,))
+ threads.append(thread_register)
+ thread_register.start()
+
+ thread_invoke = Thread(target=invoke_sproc)
+ threads.append(thread_invoke)
+ thread_invoke.start()
+
+ for thread in threads:
+ thread.join()
+
+
+def test_mocked_function_registry_created_once():
+ num_threads = 10
+
+ result = []
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ futures = [
+ executor.submit(MockedFunctionRegistry.get_or_create)
+ for _ in range(num_threads)
+ ]
+
+ for future in as_completed(futures):
+ result.append(future.result())
+
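+ # every thread should have received the same singleton registry instance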
+ registry = MockedFunctionRegistry.get_or_create()
+ assert all([registry is r for r in result])
+
+
+@pytest.mark.parametrize("test_table", [True, False])
+def test_tabular_entity_registry(test_table):
+ conn = MockServerConnection()
+ entity_registry = conn.entity_registry
+ num_threads = 10
+
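+ # each worker operates on its own randomly named table or view, so the threads only
+ # contend on the shared entity registry itself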
+ def write_read_and_drop_table():
+ table_name = Utils.random_table_name()
+ table_emulator = TableEmulator()
+
+ entity_registry.write_table(table_name, table_emulator, SaveMode.OVERWRITE)
+
+ optional_table = entity_registry.read_table_if_exists(table_name)
+ if optional_table is not None:
+ assert optional_table.empty
+
+ entity_registry.drop_table(table_name)
+
+ def write_read_and_drop_view():
+ view_name = Utils.random_view_name()
+ empty_logical_plan = LogicalPlan()
+ plan = MockExecutionPlan(empty_logical_plan, None)
+
+ entity_registry.create_or_replace_view(plan, view_name)
+
+ optional_view = entity_registry.read_view_if_exists(view_name)
+ if optional_view is not None:
+ assert optional_view.source_plan == empty_logical_plan
+
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ if test_table:
+ test_fn = write_read_and_drop_table
+ else:
+ test_fn = write_read_and_drop_view
+ futures = [executor.submit(test_fn) for _ in range(num_threads)]
+
+ for future in as_completed(futures):
+ future.result()
+
+
+def test_stage_entity_registry_put_and_get():
+ stage_registry = StageEntityRegistry(MockServerConnection())
+ num_threads = 10
+
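+ # all threads put the same local file to the same stage path and then get it back into
+ # their own temporary directory; the shared stage registry must tolerate the concurrency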
+ def put_and_get_file():
+ stage_registry.put(
+ normalize_local_file(
+ f"{os.path.dirname(os.path.abspath(__file__))}/files/test_file_1"
+ ),
+ "@test_stage/test_parent_dir/test_child_dir",
+ )
+ with tempfile.TemporaryDirectory() as temp_dir:
+ stage_registry.get(
+ "@test_stage/test_parent_dir/test_child_dir/test_file_1",
+ temp_dir,
+ )
+ assert os.path.isfile(os.path.join(temp_dir, "test_file_1"))
+
+ threads = []
+ for _ in range(num_threads):
+ thread = Thread(target=put_and_get_file)
+ threads.append(thread)
+ thread.start()
+
+ for thread in threads:
+ thread.join()
+
+
+def test_stage_entity_registry_upload_and_read(session):
+ stage_registry = StageEntityRegistry(MockServerConnection())
+ num_threads = 10
+
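+ # each thread uploads a JSON document tagged with its thread id and reads it back,
+ # checking that concurrent uploads to the same stage directory do not clobber one another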
+ def upload_and_read_json(thread_id: int):
+ json_string = json.dumps({"thread_id": thread_id})
+ bytes_io = io.BytesIO(json_string.encode("utf-8"))
+ stage_registry.upload_stream(
+ input_stream=bytes_io,
+ stage_location="@test_stage/test_parent_dir",
+ file_name=f"test_file_{thread_id}",
+ )
+
+ df = stage_registry.read_file(
+ f"@test_stage/test_parent_dir/test_file_{thread_id}",
+ "json",
+ [],
+ session._analyzer,
+ {"INFER_SCHEMA": "True"},
+ )
+
+ assert df['"thread_id"'].iloc[0] == thread_id
+
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ futures = [executor.submit(upload_and_read_json, i) for i in range(num_threads)]
+
+ for future in as_completed(futures):
+ future.result()
+
+
+def test_stage_entity_registry_create_or_replace():
+ stage_registry = StageEntityRegistry(MockServerConnection())
+ num_threads = 10
+
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
+ futures = [
+ executor.submit(stage_registry.create_or_replace_stage, f"test_stage_{i}")
+ for i in range(num_threads)
+ ]
+
+ for future in as_completed(futures):
+ future.result()
+
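+ # every thread created a distinct stage, so all of them should be present in the registry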
+ assert len(stage_registry._stage_registry) == num_threads
+ for i in range(num_threads):
+ assert f"test_stage_{i}" in stage_registry._stage_registry
+
+
+def test_oob_telemetry_add():
+ oob_service = LocalTestOOBTelemetryService.get_instance()
+ # clean up queue first
+ oob_service.export_queue_to_string()
+ num_threads = 10
+ num_events_per_thread = 10
+
+ # each thread adds num_events_per_thread events to the shared telemetry queue
+ def add_events(thread_id: int):
+ for i in range(num_events_per_thread):
+ oob_service.add(
+ {f"thread_{thread_id}_event_{i}": f"dummy_event_{thread_id}_{i}"}
+ )
+
+ # raise batch_size above the total number of events so no automatic flush is triggered
+ is_enabled = oob_service.enabled
+ oob_service.enable()
+ original_batch_size = oob_service.batch_size
+ oob_service.batch_size = num_threads * num_events_per_thread + 1
+ try:
+ # create 10 threads
+ threads = []
+ for thread_id in range(num_threads):
+ thread = Thread(target=add_events, args=(thread_id,))
+ threads.append(thread)
+ thread.start()
+
+ # wait for all threads to finish
+ for thread in threads:
+ thread.join()
+
+ # all num_threads * num_events_per_thread events should still be queued (no flush occurred)
+ assert oob_service.queue.qsize() == num_threads * num_events_per_thread
+ finally:
+ oob_service.batch_size = original_batch_size
+ if not is_enabled:
+ oob_service.disable()
+
+
+def test_oob_telemetry_flush():
+ oob_service = LocalTestOOBTelemetryService.get_instance()
+ # clean up queue first
+ oob_service.export_queue_to_string()
+
+ is_enabled = oob_service.enabled
+ oob_service.enable()
+ # add a dummy event
+ oob_service.add({"event": "dummy_event"})
+
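+ # concurrent flush() calls must not error even though only one of them has anything to export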
+ try:
+ # flush the queue in multiple threads
+ num_threads = 10
+ threads = []
+ for _ in range(num_threads):
+ thread = Thread(target=oob_service.flush)
+ threads.append(thread)
+ thread.start()
+
+ for thread in threads:
+ thread.join()
+
+ # assert that the queue is empty
+ assert oob_service.size() == 0
+ finally:
+ if not is_enabled:
+ oob_service.disable()
diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py
index 7c5e3a40bb0..d94c80b8d67 100644
--- a/tests/unit/modin/modin/test_envvars.py
+++ b/tests/unit/modin/modin/test_envvars.py
@@ -166,6 +166,7 @@ def test_overrides(self):
# Test for pandas doc when function is not defined on module.
assert pandas.read_table.__doc__ in pd.read_table.__doc__
+ @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON)
def test_not_redefining_classes_modin_issue_7138(self):
original_dataframe_class = pd.DataFrame
_init_doc_module()
diff --git a/tests/unit/modin/test_aggregation_utils.py b/tests/unit/modin/test_aggregation_utils.py
index 5434387ba71..6c9edfd024f 100644
--- a/tests/unit/modin/test_aggregation_utils.py
+++ b/tests/unit/modin/test_aggregation_utils.py
@@ -2,12 +2,20 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
+from types import MappingProxyType
+from unittest import mock
+
import numpy as np
import pytest
+import snowflake.snowpark.modin.plugin._internal.aggregation_utils as aggregation_utils
+from snowflake.snowpark.functions import greatest, sum as sum_
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
+ SnowflakeAggFunc,
+ _is_supported_snowflake_agg_func,
+ _SnowparkPandasAggregation,
check_is_aggregation_supported_in_snowflake,
- is_supported_snowflake_agg_func,
+ get_snowflake_agg_func,
)
@@ -53,8 +61,8 @@
("quantile", {}, 1, False),
],
)
-def test_is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None:
- assert is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid
+def test__is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None:
+ assert _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid
@pytest.mark.parametrize(
@@ -103,3 +111,40 @@ def test_check_aggregation_snowflake_execution_capability_by_args(
agg_func=agg_func, agg_kwargs=agg_kwargs, axis=0
)
assert can_be_distributed == expected_result
+
+
+@pytest.mark.parametrize(
+ "agg_func, agg_kwargs, axis, expected",
+ [
+ (np.sum, {}, 0, SnowflakeAggFunc(sum_, True)),
+ ("max", {"skipna": False}, 1, SnowflakeAggFunc(greatest, True)),
+ ("test", {}, 0, None),
+ ],
+)
+def test_get_snowflake_agg_func(agg_func, agg_kwargs, axis, expected):
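+ # get_snowflake_agg_func maps a pandas aggregation to its Snowflake counterpart
+ # (e.g. np.sum -> sum_, "max" on axis=1 -> greatest) and returns None when unsupported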
+ result = get_snowflake_agg_func(agg_func, agg_kwargs, axis)
+ if expected is None:
+ assert result is None
+ else:
+ assert result == expected
+
+
+def test_get_snowflake_agg_func_with_no_implementation_on_axis_0():
+ """Test get_snowflake_agg_func for a function that we support on axis=1 but not on axis=0."""
+ # We have to patch the internal dictionary
+ # _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION here because there is
+ # no real function that we support on axis=1 but not on axis=0.
+ with mock.patch.object(
+ aggregation_utils,
+ "_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION",
+ MappingProxyType(
+ {
+ "max": _SnowparkPandasAggregation(
+ preserves_snowpark_pandas_types=True,
+ axis_1_aggregation_keepna=greatest,
+ axis_1_aggregation_skipna=greatest,
+ )
+ }
+ ),
+ ):
+ assert get_snowflake_agg_func(agg_func="max", agg_kwargs={}, axis=0) is None
diff --git a/tests/unit/modin/test_series_dt.py b/tests/unit/modin/test_series_dt.py
index be0039683a8..0b5572f0592 100644
--- a/tests/unit/modin/test_series_dt.py
+++ b/tests/unit/modin/test_series_dt.py
@@ -32,8 +32,6 @@ def mock_query_compiler_for_dt_series() -> SnowflakeQueryCompiler:
[
(lambda s: s.dt.timetz, "timetz"),
(lambda s: s.dt.to_period(), "to_period"),
- (lambda s: s.dt.tz_localize(tz="UTC"), "tz_localize"),
- (lambda s: s.dt.tz_convert(tz="UTC"), "tz_convert"),
(lambda s: s.dt.strftime(date_format="YY/MM/DD"), "strftime"),
(lambda s: s.dt.qyear, "qyear"),
(lambda s: s.dt.start_time, "start_time"),
diff --git a/tests/unit/test_expression_dependent_columns.py b/tests/unit/test_expression_dependent_columns.py
index c31e5cc6290..c9b8a1ce38d 100644
--- a/tests/unit/test_expression_dependent_columns.py
+++ b/tests/unit/test_expression_dependent_columns.py
@@ -87,30 +87,37 @@
def test_expression():
a = Expression()
assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY
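+ # dependent_column_names() returns the unique dependent columns as a set, while
+ # dependent_column_names_with_duplication() preserves occurrence order and duplicates in a list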
+ assert a.dependent_column_names_with_duplication() == []
b = Expression(child=UnresolvedAttribute("a"))
assert b.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY
+ assert b.dependent_column_names_with_duplication() == []
# root class Expression always returns empty dependency
def test_literal():
a = Literal(5)
assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY
+ assert a.dependent_column_names_with_duplication() == []
def test_attribute():
a = Attribute("A", IntegerType())
assert a.dependent_column_names() == {"A"}
+ assert a.dependent_column_names_with_duplication() == ["A"]
def test_unresolved_attribute():
a = UnresolvedAttribute("A")
assert a.dependent_column_names() == {"A"}
+ assert a.dependent_column_names_with_duplication() == ["A"]
b = UnresolvedAttribute("a > 1", is_sql_text=True)
assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL
+ assert b.dependent_column_names_with_duplication() == []
c = UnresolvedAttribute("$1 > 1", is_sql_text=True)
assert c.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR
+ assert c.dependent_column_names_with_duplication() == ["$"]
def test_case_when():
@@ -118,46 +125,85 @@ def test_case_when():
b = Column("b")
z = when(a > b, col("c")).when(a < b, col("d")).else_(col("e"))
assert z._expression.dependent_column_names() == {'"A"', '"B"', '"C"', '"D"', '"E"'}
+ # verify that columns '"A"' and '"B"' each appear twice in the dependent columns
+ assert z._expression.dependent_column_names_with_duplication() == [
+ '"A"',
+ '"B"',
+ '"C"',
+ '"A"',
+ '"B"',
+ '"D"',
+ '"E"',
+ ]
def test_collate():
a = Collate(UnresolvedAttribute("a"), "spec")
assert a.dependent_column_names() == {"a"}
+ assert a.dependent_column_names_with_duplication() == ["a"]
def test_function_expression():
a = FunctionExpression("test_func", [UnresolvedAttribute(x) for x in "abcd"], False)
assert a.dependent_column_names() == set("abcd")
+ assert a.dependent_column_names_with_duplication() == list("abcd")
+
+ # expression with duplicated dependent columns
+ b = FunctionExpression(
+ "test_func", [UnresolvedAttribute(x) for x in "abcdad"], False
+ )
+ assert b.dependent_column_names() == set("abcd")
+ assert b.dependent_column_names_with_duplication() == list("abcdad")
def test_in_expression():
a = InExpression(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"])
assert a.dependent_column_names() == set("abcde")
+ assert a.dependent_column_names_with_duplication() == list("eabcd")
def test_like():
a = Like(UnresolvedAttribute("a"), UnresolvedAttribute("b"))
assert a.dependent_column_names() == {"a", "b"}
+ assert a.dependent_column_names_with_duplication() == ["a", "b"]
+
+ # with duplication
+ b = Like(UnresolvedAttribute("a"), UnresolvedAttribute("a"))
+ assert b.dependent_column_names() == {"a"}
+ assert b.dependent_column_names_with_duplication() == ["a", "a"]
def test_list_agg():
a = ListAgg(UnresolvedAttribute("a"), ",", True)
assert a.dependent_column_names() == {"a"}
+ assert a.dependent_column_names_with_duplication() == ["a"]
def test_multiple_expression():
a = MultipleExpression([UnresolvedAttribute(x) for x in "abcd"])
assert a.dependent_column_names() == set("abcd")
+ assert a.dependent_column_names_with_duplication() == list("abcd")
+
+ # with duplication
+ a = MultipleExpression([UnresolvedAttribute(x) for x in "abcdbea"])
+ assert a.dependent_column_names() == set("abcde")
+ assert a.dependent_column_names_with_duplication() == list("abcdbea")
def test_reg_exp():
a = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("b"))
assert a.dependent_column_names() == {"a", "b"}
+ assert a.dependent_column_names_with_duplication() == ["a", "b"]
+
+ b = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("a"))
+ assert b.dependent_column_names() == {"a"}
+ assert b.dependent_column_names_with_duplication() == ["a", "a"]
def test_scalar_subquery():
a = ScalarSubquery(None)
assert a.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR
+ assert a.dependent_column_names_with_duplication() == list(COLUMN_DEPENDENCY_DOLLAR)
def test_snowflake_udf():
@@ -165,21 +211,42 @@ def test_snowflake_udf():
"udf_name", [UnresolvedAttribute(x) for x in "abcd"], IntegerType()
)
assert a.dependent_column_names() == set("abcd")
+ assert a.dependent_column_names_with_duplication() == list("abcd")
+
+ # with duplication
+ b = SnowflakeUDF(
+ "udf_name", [UnresolvedAttribute(x) for x in "abcdfc"], IntegerType()
+ )
+ assert b.dependent_column_names() == set("abcdf")
+ assert b.dependent_column_names_with_duplication() == list("abcdfc")
def test_star():
a = Star([Attribute(x, IntegerType()) for x in "abcd"])
assert a.dependent_column_names() == set("abcd")
+ assert a.dependent_column_names_with_duplication() == list("abcd")
+
+ b = Star([])
+ assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL
+ assert b.dependent_column_names_with_duplication() == []
def test_subfield_string():
a = SubfieldString(UnresolvedAttribute("a"), "field")
assert a.dependent_column_names() == {"a"}
+ assert a.dependent_column_names_with_duplication() == ["a"]
def test_within_group():
a = WithinGroup(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"])
assert a.dependent_column_names() == set("abcde")
+ assert a.dependent_column_names_with_duplication() == list("eabcd")
+
+ b = WithinGroup(
+ UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcdea"]
+ )
+ assert b.dependent_column_names() == set("abcde")
+ assert b.dependent_column_names_with_duplication() == list("eabcdea")
@pytest.mark.parametrize(
@@ -189,16 +256,19 @@ def test_within_group():
def test_unary_expression(expression_class):
a = expression_class(child=UnresolvedAttribute("a"))
assert a.dependent_column_names() == {"a"}
+ assert a.dependent_column_names_with_duplication() == ["a"]
def test_alias():
a = Alias(child=Add(UnresolvedAttribute("a"), UnresolvedAttribute("b")), name="c")
assert a.dependent_column_names() == {"a", "b"}
+ assert a.dependent_column_names_with_duplication() == ["a", "b"]
def test_cast():
a = Cast(UnresolvedAttribute("a"), IntegerType())
assert a.dependent_column_names() == {"a"}
+ assert a.dependent_column_names_with_duplication() == ["a"]
@pytest.mark.parametrize(
@@ -234,6 +304,19 @@ def test_binary_expression(expression_class):
assert b.dependent_column_names() == {"B"}
assert binary_expression.dependent_column_names() == {"A", "B"}
+ assert a.dependent_column_names_with_duplication() == ["A"]
+ assert b.dependent_column_names_with_duplication() == ["B"]
+ assert binary_expression.dependent_column_names_with_duplication() == ["A", "B"]
+
+ # hierarchical expressions with duplication
+ hierarchical_binary_expression = expression_class(expression_class(a, b), b)
+ assert hierarchical_binary_expression.dependent_column_names() == {"A", "B"}
+ assert hierarchical_binary_expression.dependent_column_names_with_duplication() == [
+ "A",
+ "B",
+ "B",
+ ]
+
@pytest.mark.parametrize(
"expression_class",
@@ -253,6 +336,18 @@ def test_grouping_set(expression_class):
]
)
assert a.dependent_column_names() == {"a", "b", "c", "d"}
+ assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"]
+
+ # with duplication
+ b = expression_class(
+ [
+ UnresolvedAttribute("a"),
+ UnresolvedAttribute("a"),
+ UnresolvedAttribute("c"),
+ ]
+ )
+ assert b.dependent_column_names() == {"a", "c"}
+ assert b.dependent_column_names_with_duplication() == ["a", "a", "c"]
def test_grouping_sets_expression():
@@ -263,11 +358,13 @@ def test_grouping_sets_expression():
]
)
assert a.dependent_column_names() == {"a", "b", "c", "d"}
+ assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"]
def test_sort_order():
a = SortOrder(UnresolvedAttribute("a"), Ascending())
assert a.dependent_column_names() == {"a"}
+ assert a.dependent_column_names_with_duplication() == ["a"]
def test_specified_window_frame():
@@ -275,12 +372,21 @@ def test_specified_window_frame():
RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("b")
)
assert a.dependent_column_names() == {"a", "b"}
+ assert a.dependent_column_names_with_duplication() == ["a", "b"]
+
+ # with duplication
+ b = SpecifiedWindowFrame(
+ RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("a")
+ )
+ assert b.dependent_column_names() == {"a"}
+ assert b.dependent_column_names_with_duplication() == ["a", "a"]
@pytest.mark.parametrize("expression_class", [RankRelatedFunctionExpression, Lag, Lead])
def test_rank_related_function_expression(expression_class):
a = expression_class(UnresolvedAttribute("a"), 1, UnresolvedAttribute("b"), False)
assert a.dependent_column_names() == {"a", "b"}
+ assert a.dependent_column_names_with_duplication() == ["a", "b"]
def test_window_spec_definition():
@@ -295,6 +401,7 @@ def test_window_spec_definition():
),
)
assert a.dependent_column_names() == set("abcdef")
+ assert a.dependent_column_names_with_duplication() == list("abcdef")
def test_window_expression():
@@ -310,6 +417,23 @@ def test_window_expression():
)
a = WindowExpression(UnresolvedAttribute("x"), window_spec_definition)
assert a.dependent_column_names() == set("abcdefx")
+ assert a.dependent_column_names_with_duplication() == list("xabcdef")
+
+
+def test_window_expression_with_duplication_columns():
+ window_spec_definition = WindowSpecDefinition(
+ [UnresolvedAttribute("a"), UnresolvedAttribute("b")],
+ [
+ SortOrder(UnresolvedAttribute("c"), Ascending()),
+ SortOrder(UnresolvedAttribute("a"), Ascending()),
+ ],
+ SpecifiedWindowFrame(
+ RowFrame(), UnresolvedAttribute("e"), UnresolvedAttribute("f")
+ ),
+ )
+ a = WindowExpression(UnresolvedAttribute("e"), window_spec_definition)
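+ # expected order: the window function's column "e", partition columns "a" and "b",
+ # order-by columns "c" and "a", then the frame boundary columns "e" and "f"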
+ assert a.dependent_column_names() == set("abcef")
+ assert a.dependent_column_names_with_duplication() == list("eabcaef")
@pytest.mark.parametrize(
@@ -325,3 +449,4 @@ def test_window_expression():
def test_other_window_expressions(expression_class):
a = expression_class()
assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY
+ assert a.dependent_column_names_with_duplication() == []
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 262c9e82c44..370ee455d62 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -112,6 +112,7 @@ def test_used_scoped_temp_object():
def test_close_exception():
fake_connection = mock.create_autospec(ServerConnection)
fake_connection._conn = mock.Mock()
+ fake_connection._telemetry_client = mock.Mock()
fake_connection.is_closed = MagicMock(return_value=False)
exception_msg = "Mock exception for session.cancel_all"
fake_connection.run_query = MagicMock(side_effect=Exception(exception_msg))