From 6879ac8099196ba6b70ccdb7a8c1c6d0687a1e41 Mon Sep 17 00:00:00 2001
From: Andong Zhan
Date: Wed, 11 Sep 2024 13:20:48 -0700
Subject: [PATCH 1/5] SNOW-1641472 Refactor binary ops utility to improve readability (#2257)

1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.

   Fixes SNOW-1641472

2. Fill out the following pre-review checklist:

   - [ ] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.

3. Please describe how your code solves the related issue.

   Improve the readability of the binary ops utility by reorganizing it as a `BinaryOp` class.

This pull request focuses on improving the readability and maintainability of the binary operation utility code within the Snowflake query compiler. The changes primarily involve refactoring the binary operation functions into a new `BinaryOp` class and updating the relevant method calls throughout the codebase.

### Refactoring binary operations:

* `src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py`: Replaced multiple function calls (`compute_binary_op_between_scalar_and_snowpark_column`, `compute_binary_op_between_snowpark_column_and_scalar`, `compute_binary_op_between_snowpark_columns`, `compute_binary_op_with_fill_value`, `is_binary_op_supported`) with the new `BinaryOp` class methods (`create`, `create_with_fill_value`, `create_with_lhs_scalar`, `create_with_rhs_scalar`, `is_binary_op_supported`).
### Documentation updates:

* `CHANGELOG.md`: Added an entry to reflect the improved readability of the binary operation utility code.
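For reference, here is a minimal sketch of how the refactored API is exercised at the updated call sites; the column names, `LongType` datatypes, and literal operands below are illustrative placeholders, not values taken from this patch:

```python
# Minimal sketch (not part of the patch) of the call pattern after this refactor.
# The columns, datatypes, and fill_value are placeholders; real call sites pass
# Snowpark columns and lazy type getters supplied by the query compiler.
from snowflake.snowpark.functions import col
from snowflake.snowpark.modin.plugin._internal.binary_op_utils import BinaryOp
from snowflake.snowpark.types import LongType

lhs, rhs = col('"A"'), col('"B"')
lhs_type = rhs_type = lambda: LongType()  # DataTypeGetter: a callable evaluated lazily

if BinaryOp.is_binary_op_supported("add"):
    # Replaces compute_binary_op_between_snowpark_columns(...)
    result = BinaryOp.create(
        op="add",
        first_operand=lhs,
        first_datatype=lhs_type,
        second_operand=rhs,
        second_datatype=rhs_type,
    ).compute()  # -> SnowparkPandasColumn(snowpark_column, snowpark_pandas_type)

    # Replaces compute_binary_op_with_fill_value(...): fills NaN/None on the side
    # that is null (when the other side is not) before computing the op.
    filled = BinaryOp.create_with_fill_value(
        op="add",
        lhs=lhs,
        lhs_datatype=lhs_type,
        rhs=rhs,
        rhs_datatype=rhs_type,
        fill_value=0,
    ).compute()

    # Replaces compute_binary_op_between_snowpark_column_and_scalar(...)
    scalar_result = BinaryOp.create_with_rhs_scalar("add", lhs, lhs_type, 1).compute()
```

Per the patch, `BinaryOp.create` dispatches to an op-specific subclass (e.g. `AddOp`, `SubOp`) located by converting the op name to camel case, and falls back to the base class; `compute` then defers to the corresponding Snowpark `Column` dunder method (e.g. `__add__`) whenever no special handling set a result column.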
--- .../modin/plugin/_internal/binary_op_utils.py | 910 ++++++++++-------- .../compiler/snowflake_query_compiler.py | 93 +- 2 files changed, 552 insertions(+), 451 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py index 1aa81b36e64..475fbfcefa7 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py @@ -185,19 +185,6 @@ def compute_power_between_snowpark_columns( return result -def is_binary_op_supported(op: str) -> bool: - """ - check whether binary operation is mappable to Snowflake - Args - op: op as string - - Returns: - True if binary operation can be mapped to Snowflake/Snowpark, else False - """ - - return op in SUPPORTED_BINARY_OPERATIONS - - def _compute_subtraction_between_snowpark_timestamp_columns( first_operand: SnowparkColumn, first_datatype: DataType, @@ -312,314 +299,527 @@ def _op_is_between_timedelta_and_numeric( ) -def compute_binary_op_between_snowpark_columns( - op: str, - first_operand: SnowparkColumn, - first_datatype: DataTypeGetter, - second_operand: SnowparkColumn, - second_datatype: DataTypeGetter, -) -> SnowparkPandasColumn: - """ - Compute pandas binary operation for two SnowparkColumns - Args: - op: pandas operation - first_operand: SnowparkColumn for lhs - first_datatype: Callable for Snowpark Datatype for lhs - second_operand: SnowparkColumn for rhs - second_datatype: Callable for Snowpark DateType for rhs - it is not needed. +class BinaryOp: + def __init__( + self, + op: str, + first_operand: SnowparkColumn, + first_datatype: DataTypeGetter, + second_operand: SnowparkColumn, + second_datatype: DataTypeGetter, + ) -> None: + """ + Construct a BinaryOp object to compute pandas binary operation for two SnowparkColumns + Args: + op: pandas operation + first_operand: SnowparkColumn for lhs + first_datatype: Callable for Snowpark Datatype for lhs + second_operand: SnowparkColumn for rhs + second_datatype: Callable for Snowpark DateType for rhs + it is not needed. + """ + self.op = op + self.first_operand = first_operand + self.first_datatype = first_datatype + self.second_operand = second_operand + self.second_datatype = second_datatype + self.result_column = None + self.result_snowpark_pandas_type = None + + @staticmethod + def is_binary_op_supported(op: str) -> bool: + """ + check whether binary operation is mappable to Snowflake + Args + op: op as string + + Returns: + True if binary operation can be mapped to Snowflake/Snowpark, else False + """ + + return op in SUPPORTED_BINARY_OPERATIONS + + @staticmethod + def create( + op: str, + first_operand: SnowparkColumn, + first_datatype: DataTypeGetter, + second_operand: SnowparkColumn, + second_datatype: DataTypeGetter, + ) -> "BinaryOp": + """ + Create a BinaryOp object to compute pandas binary operation for two SnowparkColumns + Args: + op: pandas operation + first_operand: SnowparkColumn for lhs + first_datatype: Callable for Snowpark Datatype for lhs + second_operand: SnowparkColumn for rhs + second_datatype: Callable for Snowpark DateType for rhs + it is not needed. + """ + + def snake_to_camel(snake_str: str) -> str: + """Converts a snake case string to camel case.""" + components = snake_str.split("_") + return "".join(x.title() for x in components) + + if op in _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP: + # Normalize right-sided binary operations to the equivalent left-sided + # operations with swapped operands. 
For example, rsub(col(a), col(b)) + # becomes sub(col(b), col(a)) + op, first_operand, first_datatype, second_operand, second_datatype = ( + _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP[op], + second_operand, + second_datatype, + first_operand, + first_datatype, + ) - Returns: - SnowparkPandasColumn for translated pandas operation - """ - if op in _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP: - # Normalize right-sided binary operations to the equivalent left-sided - # operations with swapped operands. For example, rsub(col(a), col(b)) - # becomes sub(col(b), col(a)) - op, first_operand, first_datatype, second_operand, second_datatype = ( - _RIGHT_BINARY_OP_TO_LEFT_BINARY_OP[op], - second_operand, - second_datatype, - first_operand, - first_datatype, + class_name = f"{snake_to_camel(op)}Op" + op_class = None + for subclass in BinaryOp.__subclasses__(): + if subclass.__name__ == class_name: + op_class = subclass + if op_class is None: + op_class = BinaryOp + return op_class( + op, first_operand, first_datatype, second_operand, second_datatype ) - binary_op_result_column = None - snowpark_pandas_type = None + @staticmethod + def create_with_fill_value( + op: str, + lhs: SnowparkColumn, + lhs_datatype: DataTypeGetter, + rhs: SnowparkColumn, + rhs_datatype: DataTypeGetter, + fill_value: Scalar, + ) -> "BinaryOp": + """ + Create a BinaryOp object to compute pandas binary operation for two SnowparkColumns with fill value for missing + values. + + Args: + op: pandas operation + first_operand: SnowparkColumn for lhs + first_datatype: Callable for Snowpark Datatype for lhs + second_operand: SnowparkColumn for rhs + second_datatype: Callable for Snowpark DateType for rhs + it is not needed. + fill_value: the value to fill missing values + + Helper method for performing binary operations. + 1. Fills NaN/None values in the lhs and rhs with the given fill_value. + 2. Computes the binary operation expression for lhs rhs. + + fill_value replaces NaN/None values when only either lhs or rhs is NaN/None, not both lhs and rhs. + For instance, with fill_value = 100, + 1. Given lhs = None and rhs = 10, lhs is replaced with fill_value. + result = lhs + rhs => None + 10 => 100 (replaced) + 10 = 110 + 2. Given lhs = 3 and rhs = None, rhs is replaced with fill_value. + result = lhs + rhs => 3 + None => 3 + 100 (replaced) = 103 + 3. Given lhs = None and rhs = None, neither lhs nor rhs is replaced since they both are None. + result = lhs + rhs => None + None => None. + + Args: + op: pandas operation to perform between lhs and rhs + lhs: the lhs SnowparkColumn + lhs_datatype: Callable for Snowpark Datatype for lhs + rhs: the rhs SnowparkColumn + rhs_datatype: Callable for Snowpark Datatype for rhs + fill_value: Fill existing missing (NaN) values, and any new element needed for + successful DataFrame alignment, with this value before computation. + + Returns: + SnowparkPandasColumn for translated pandas operation + """ + lhs_cond, rhs_cond = lhs, rhs + if fill_value is not None: + fill_value_lit = pandas_lit(fill_value) + lhs_cond = iff(lhs.is_null() & ~rhs.is_null(), fill_value_lit, lhs) + rhs_cond = iff(rhs.is_null() & ~lhs.is_null(), fill_value_lit, rhs) + + return BinaryOp.create(op, lhs_cond, lhs_datatype, rhs_cond, rhs_datatype) + + @staticmethod + def create_with_rhs_scalar( + op: str, + first_operand: SnowparkColumn, + datatype: DataTypeGetter, + second_operand: Scalar, + ) -> "BinaryOp": + """ + Compute the binary operation between a Snowpark column and a scalar. 
+ Args: + op: the name of binary operation + first_operand: The SnowparkColumn for lhs + datatype: Callable for Snowpark data type + second_operand: Scalar value + + Returns: + SnowparkPandasColumn for translated pandas operation + """ + + def second_datatype() -> DataType: + return infer_object_type(second_operand) + + return BinaryOp.create( + op, first_operand, datatype, pandas_lit(second_operand), second_datatype + ) - # some operators and the data types have to be handled specially to align with pandas - # However, it is difficult to fail early if the arithmetic operator is not compatible - # with the data type, so we just let the server raise exception (e.g. a string minus a string). - if ( - op == "add" - and isinstance(second_datatype(), TimedeltaType) - and isinstance(first_datatype(), TimestampType) - ): - binary_op_result_column = dateadd("ns", second_operand, first_operand) - elif ( - op == "add" - and isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), TimestampType) - ): - binary_op_result_column = dateadd("ns", first_operand, second_operand) - elif op in ( - "add", - "sub", - "eq", - "ne", - "gt", - "ge", - "lt", - "le", - "floordiv", - "truediv", - ) and ( - ( - isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), NullType) + @staticmethod + def create_with_lhs_scalar( + op: str, + first_operand: Scalar, + second_operand: SnowparkColumn, + datatype: DataTypeGetter, + ) -> "BinaryOp": + """ + Compute the binary operation between a scalar and a Snowpark column. + Args: + op: the name of binary operation + first_operand: Scalar value + second_operand: The SnowparkColumn for rhs + datatype: Callable for Snowpark data type + it is not needed. + + Returns: + SnowparkPandasColumn for translated pandas operation + """ + + def first_datatype() -> DataType: + return infer_object_type(first_operand) + + return BinaryOp.create( + op, pandas_lit(first_operand), first_datatype, second_operand, datatype ) - or ( - isinstance(second_datatype(), TimedeltaType) - and isinstance(first_datatype(), NullType) + + def _custom_compute(self) -> None: + """Implement custom compute method if needed.""" + pass + + def _get_result(self) -> SnowparkPandasColumn: + return SnowparkPandasColumn( + snowpark_column=self.result_column, + snowpark_pandas_type=self.result_snowpark_pandas_type, ) - ): - return SnowparkPandasColumn(pandas_lit(None), TimedeltaType()) - elif ( - op == "sub" - and isinstance(second_datatype(), TimedeltaType) - and isinstance(first_datatype(), TimestampType) - ): - binary_op_result_column = dateadd("ns", -1 * second_operand, first_operand) - elif ( - op == "sub" - and isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), TimestampType) - ): + + def _check_timedelta_with_none(self) -> None: + if self.op in ( + "add", + "sub", + "eq", + "ne", + "gt", + "ge", + "lt", + "le", + "floordiv", + "truediv", + ) and ( + ( + isinstance(self.first_datatype(), TimedeltaType) + and isinstance(self.second_datatype(), NullType) + ) + or ( + isinstance(self.second_datatype(), TimedeltaType) + and isinstance(self.first_datatype(), NullType) + ) + ): + self.result_column = pandas_lit(None) + self.result_snowpark_pandas_type = TimedeltaType() + + def _check_error(self) -> None: # Timedelta - Timestamp doesn't make sense. Raise the same error # message as pandas. 
- raise TypeError("bad operand type for unary -: 'DatetimeArray'") - elif op == "mod" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - binary_op_result_column = compute_modulo_between_snowpark_columns( - first_operand, first_datatype(), second_operand, second_datatype() - ) - snowpark_pandas_type = TimedeltaType() - elif op == "pow" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - raise TypeError("unsupported operand type for **: Timedelta") - elif op == "__or__" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - raise TypeError("unsupported operand type for |: Timedelta") - elif op == "__and__" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - raise TypeError("unsupported operand type for &: Timedelta") - elif ( - op in ("add", "sub") - and isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), TimedeltaType) - ): - snowpark_pandas_type = TimedeltaType() - elif op == "mul" and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - raise np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined] - np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]")) - ) - elif op in ( - "eq", - "ne", - "gt", - "ge", - "lt", - "le", - ) and _op_is_between_two_timedeltas_or_timedelta_and_null( - first_datatype(), second_datatype() - ): - # These operations, when done between timedeltas, work without any - # extra handling in `snowpark_pandas_type` or `binary_op_result_column`. - pass - elif op == "mul" and ( - _op_is_between_timedelta_and_numeric(first_datatype, second_datatype) - ): - binary_op_result_column = cast( - floor(first_operand * second_operand), LongType() - ) - snowpark_pandas_type = TimedeltaType() - # For `eq` and `ne`, note that Snowflake will consider 1 equal to - # Timedelta(1) because those two have the same representation in Snowflake, - # so we have to compare types in the client. 
- elif op == "eq" and ( - _op_is_between_timedelta_and_numeric(first_datatype, second_datatype) - ): - binary_op_result_column = pandas_lit(False) - elif op == "ne" and _op_is_between_timedelta_and_numeric( - first_datatype, second_datatype - ): - binary_op_result_column = pandas_lit(True) - elif ( - op in ("truediv", "floordiv") - and isinstance(first_datatype(), TimedeltaType) - and _is_numeric_non_timedelta_type(second_datatype()) - ): - binary_op_result_column = cast( - floor(first_operand / second_operand), LongType() - ) - snowpark_pandas_type = TimedeltaType() - elif ( - op == "mod" - and isinstance(first_datatype(), TimedeltaType) - and _is_numeric_non_timedelta_type(second_datatype()) - ): - binary_op_result_column = ceil( - compute_modulo_between_snowpark_columns( - first_operand, first_datatype(), second_operand, second_datatype() + if ( + self.op == "sub" + and isinstance(self.first_datatype(), TimedeltaType) + and isinstance(self.second_datatype(), TimestampType) + ): + raise TypeError("bad operand type for unary -: 'DatetimeArray'") + + # Raise error for two timedelta or timedelta and null + two_timedeltas_or_timedelta_and_null_error = { + "pow": TypeError("unsupported operand type for **: Timedelta"), + "__or__": TypeError("unsupported operand type for |: Timedelta"), + "__and__": TypeError("unsupported operand type for &: Timedelta"), + "mul": np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined] + np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]")) + ), + } + if ( + self.op in two_timedeltas_or_timedelta_and_null_error + and _op_is_between_two_timedeltas_or_timedelta_and_null( + self.first_datatype(), self.second_datatype() ) - ) - snowpark_pandas_type = TimedeltaType() - elif op in ("add", "sub") and ( - ( - isinstance(first_datatype(), TimedeltaType) - and _is_numeric_non_timedelta_type(second_datatype()) - ) - or ( - _is_numeric_non_timedelta_type(first_datatype()) - and isinstance(second_datatype(), TimedeltaType) - ) - ): - raise TypeError( - "Snowpark pandas does not support addition or subtraction between timedelta values and numeric values." - ) - elif op in ("truediv", "floordiv", "mod") and ( - _is_numeric_non_timedelta_type(first_datatype()) - and isinstance(second_datatype(), TimedeltaType) - ): - raise TypeError( - "Snowpark pandas does not support dividing numeric values by timedelta values with div (/), mod (%), or floordiv (//)." - ) - elif op in ( - "add", - "sub", - "truediv", - "floordiv", - "mod", - "gt", - "ge", - "lt", - "le", - "ne", - "eq", - ) and ( - ( - isinstance(first_datatype(), TimedeltaType) - and isinstance(second_datatype(), StringType) - ) - or ( - isinstance(second_datatype(), TimedeltaType) - and isinstance(first_datatype(), StringType) - ) - ): + ): + raise two_timedeltas_or_timedelta_and_null_error[self.op] + + if self.op in ("add", "sub") and ( + ( + isinstance(self.first_datatype(), TimedeltaType) + and _is_numeric_non_timedelta_type(self.second_datatype()) + ) + or ( + _is_numeric_non_timedelta_type(self.first_datatype()) + and isinstance(self.second_datatype(), TimedeltaType) + ) + ): + raise TypeError( + "Snowpark pandas does not support addition or subtraction between timedelta values and numeric values." 
+ ) + + if self.op in ("truediv", "floordiv", "mod") and ( + _is_numeric_non_timedelta_type(self.first_datatype()) + and isinstance(self.second_datatype(), TimedeltaType) + ): + raise TypeError( + "Snowpark pandas does not support dividing numeric values by timedelta values with div (/), mod (%), " + "or floordiv (//)." + ) + # TODO(SNOW-1646604): Support these cases. - ErrorMessage.not_implemented( - f"Snowpark pandas does not yet support the operation {op} between timedelta and string" - ) - elif op in ("gt", "ge", "lt", "le", "pow", "__or__", "__and__") and ( - _op_is_between_timedelta_and_numeric(first_datatype, second_datatype) - ): - raise TypeError( - f"Snowpark pandas does not support binary operation {op} between timedelta and a non-timedelta type." - ) - elif op == "floordiv": - binary_op_result_column = floor(first_operand / second_operand) - elif op == "mod": - binary_op_result_column = compute_modulo_between_snowpark_columns( - first_operand, first_datatype(), second_operand, second_datatype() - ) - elif op == "pow": - binary_op_result_column = compute_power_between_snowpark_columns( - first_operand, second_operand - ) - elif op == "__or__": - binary_op_result_column = first_operand | second_operand - elif op == "__and__": - binary_op_result_column = first_operand & second_operand - elif ( - op == "add" - and isinstance(second_datatype(), StringType) - and isinstance(first_datatype(), StringType) - ): - # string/string case (only for add) - binary_op_result_column = concat(first_operand, second_operand) - elif op == "mul" and ( - ( - isinstance(second_datatype(), _IntegralType) - and isinstance(first_datatype(), StringType) - ) - or ( - isinstance(second_datatype(), StringType) - and isinstance(first_datatype(), _IntegralType) + if self.op in ( + "add", + "sub", + "truediv", + "floordiv", + "mod", + "gt", + "ge", + "lt", + "le", + "ne", + "eq", + ) and ( + ( + isinstance(self.first_datatype(), TimedeltaType) + and isinstance(self.second_datatype(), StringType) + ) + or ( + isinstance(self.second_datatype(), TimedeltaType) + and isinstance(self.first_datatype(), StringType) + ) + ): + ErrorMessage.not_implemented( + f"Snowpark pandas does not yet support the operation {self.op} between timedelta and string" + ) + + if self.op in ("gt", "ge", "lt", "le", "pow", "__or__", "__and__") and ( + _op_is_between_timedelta_and_numeric( + self.first_datatype, self.second_datatype + ) + ): + raise TypeError( + f"Snowpark pandas does not support binary operation {self.op} between timedelta and a non-timedelta " + f"type." + ) + + def compute(self) -> SnowparkPandasColumn: + self._check_error() + + self._check_timedelta_with_none() + + if self.result_column is not None: + return self._get_result() + + # Generally, some operators and the data types have to be handled specially to align with pandas + # However, it is difficult to fail early if the arithmetic operator is not compatible + # with the data type, so we just let the server raise exception (e.g. a string minus a string). + + self._custom_compute() + if self.result_column is None: + # If there is no special binary_op_result_column result, it means the operator and + # the data type of the column don't need special handling. Then we get the overloaded + # operator from Snowpark Column class, e.g., __add__ to perform binary operations. 
+ self.result_column = getattr(self.first_operand, f"__{self.op}__")( + self.second_operand + ) + + return self._get_result() + + +class AddOp(BinaryOp): + def _custom_compute(self) -> None: + if isinstance(self.second_datatype(), TimedeltaType) and isinstance( + self.first_datatype(), TimestampType + ): + self.result_column = dateadd("ns", self.second_operand, self.first_operand) + elif isinstance(self.first_datatype(), TimedeltaType) and isinstance( + self.second_datatype(), TimestampType + ): + self.result_column = dateadd("ns", self.first_operand, self.second_operand) + elif isinstance(self.first_datatype(), TimedeltaType) and isinstance( + self.second_datatype(), TimedeltaType + ): + self.result_snowpark_pandas_type = TimedeltaType() + elif isinstance(self.second_datatype(), StringType) and isinstance( + self.first_datatype(), StringType + ): + # string/string case (only for add) + self.result_column = concat(self.first_operand, self.second_operand) + + +class SubOp(BinaryOp): + def _custom_compute(self) -> None: + if isinstance(self.second_datatype(), TimedeltaType) and isinstance( + self.first_datatype(), TimestampType + ): + self.result_column = dateadd( + "ns", -1 * self.second_operand, self.first_operand + ) + elif isinstance(self.first_datatype(), TimedeltaType) and isinstance( + self.second_datatype(), TimedeltaType + ): + self.result_snowpark_pandas_type = TimedeltaType() + elif isinstance(self.first_datatype(), TimestampType) and isinstance( + self.second_datatype(), NullType + ): + # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, + # but it's valid in pandas and returns NULL. + self.result_column = pandas_lit(None) + elif isinstance(self.first_datatype(), NullType) and isinstance( + self.second_datatype(), TimestampType + ): + # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, + # but it's valid in pandas and returns NULL. + self.result_column = pandas_lit(None) + elif isinstance(self.first_datatype(), TimestampType) and isinstance( + self.second_datatype(), TimestampType + ): + ( + self.result_column, + self.result_snowpark_pandas_type, + ) = _compute_subtraction_between_snowpark_timestamp_columns( + first_operand=self.first_operand, + first_datatype=self.first_datatype(), + second_operand=self.second_operand, + second_datatype=self.second_datatype(), + ) + + +class ModOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = compute_modulo_between_snowpark_columns( + self.first_operand, + self.first_datatype(), + self.second_operand, + self.second_datatype(), ) - ): - # string/integer case (only for mul/rmul). 
- # swap first_operand with second_operand because - # REPEAT(, ) expects to be string - if isinstance(first_datatype(), _IntegralType): - first_operand, second_operand = second_operand, first_operand - - binary_op_result_column = iff( - second_operand > pandas_lit(0), - repeat(first_operand, second_operand), - # Snowflake's repeat doesn't support negative number, - # but pandas will return an empty string - pandas_lit(""), + if _op_is_between_two_timedeltas_or_timedelta_and_null( + self.first_datatype(), self.second_datatype() + ): + self.result_snowpark_pandas_type = TimedeltaType() + elif isinstance( + self.first_datatype(), TimedeltaType + ) and _is_numeric_non_timedelta_type(self.second_datatype()): + self.result_column = ceil(self.result_column) + self.result_snowpark_pandas_type = TimedeltaType() + + +class MulOp(BinaryOp): + def _custom_compute(self) -> None: + if _op_is_between_timedelta_and_numeric( + self.first_datatype, self.second_datatype + ): + self.result_column = cast( + floor(self.first_operand * self.second_operand), LongType() + ) + self.result_snowpark_pandas_type = TimedeltaType() + elif ( + isinstance(self.second_datatype(), _IntegralType) + and isinstance(self.first_datatype(), StringType) + ) or ( + isinstance(self.second_datatype(), StringType) + and isinstance(self.first_datatype(), _IntegralType) + ): + # string/integer case (only for mul/rmul). + # swap first_operand with second_operand because + # REPEAT(, ) expects to be string + if isinstance(self.first_datatype(), _IntegralType): + self.first_operand, self.second_operand = ( + self.second_operand, + self.first_operand, + ) + + self.result_column = iff( + self.second_operand > pandas_lit(0), + repeat(self.first_operand, self.second_operand), + # Snowflake's repeat doesn't support negative number, + # but pandas will return an empty string + pandas_lit(""), + ) + + +class EqOp(BinaryOp): + def _custom_compute(self) -> None: + # For `eq` and `ne`, note that Snowflake will consider 1 equal to + # Timedelta(1) because those two have the same representation in Snowflake, + # so we have to compare types in the client. + if _op_is_between_timedelta_and_numeric( + self.first_datatype, self.second_datatype + ): + self.result_column = pandas_lit(False) + + +class NeOp(BinaryOp): + def _custom_compute(self) -> None: + # For `eq` and `ne`, note that Snowflake will consider 1 equal to + # Timedelta(1) because those two have the same representation in Snowflake, + # so we have to compare types in the client. 
+ if _op_is_between_timedelta_and_numeric( + self.first_datatype, self.second_datatype + ): + self.result_column = pandas_lit(True) + + +class FloordivOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = floor(self.first_operand / self.second_operand) + if isinstance( + self.first_datatype(), TimedeltaType + ) and _is_numeric_non_timedelta_type(self.second_datatype()): + self.result_column = cast(self.result_column, LongType()) + self.result_snowpark_pandas_type = TimedeltaType() + + +class TruedivOp(BinaryOp): + def _custom_compute(self) -> None: + if isinstance( + self.first_datatype(), TimedeltaType + ) and _is_numeric_non_timedelta_type(self.second_datatype()): + self.result_column = cast( + floor(self.first_operand / self.second_operand), LongType() + ) + self.result_snowpark_pandas_type = TimedeltaType() + + +class PowOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = compute_power_between_snowpark_columns( + self.first_operand, self.second_operand ) - elif op == "equal_null": + + +class OrOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = self.first_operand | self.second_operand + + +class AndOp(BinaryOp): + def _custom_compute(self) -> None: + self.result_column = self.first_operand & self.second_operand + + +class EqualNullOp(BinaryOp): + def _custom_compute(self) -> None: # TODO(SNOW-1641716): In Snowpark pandas, generally use this equal_null # with type checking intead of snowflake.snowpark.functions.equal_null. - if not are_equal_types(first_datatype(), second_datatype()): - binary_op_result_column = pandas_lit(False) + if not are_equal_types(self.first_datatype(), self.second_datatype()): + self.result_column = pandas_lit(False) else: - binary_op_result_column = first_operand.equal_null(second_operand) - elif ( - op == "sub" - and isinstance(first_datatype(), TimestampType) - and isinstance(second_datatype(), NullType) - ): - # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, - # but it's valid in pandas and returns NULL. - binary_op_result_column = pandas_lit(None) - elif ( - op == "sub" - and isinstance(first_datatype(), NullType) - and isinstance(second_datatype(), TimestampType) - ): - # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, - # but it's valid in pandas and returns NULL. - binary_op_result_column = pandas_lit(None) - elif ( - op == "sub" - and isinstance(first_datatype(), TimestampType) - and isinstance(second_datatype(), TimestampType) - ): - return _compute_subtraction_between_snowpark_timestamp_columns( - first_operand=first_operand, - first_datatype=first_datatype(), - second_operand=second_operand, - second_datatype=second_datatype(), - ) - # If there is no special binary_op_result_column result, it means the operator and - # the data type of the column don't need special handling. Then we get the overloaded - # operator from Snowpark Column class, e.g., __add__ to perform binary operations. 
- if binary_op_result_column is None: - binary_op_result_column = getattr(first_operand, f"__{op}__")(second_operand) - - return SnowparkPandasColumn( - snowpark_column=binary_op_result_column, - snowpark_pandas_type=snowpark_pandas_type, - ) + self.result_column = self.first_operand.equal_null(self.second_operand) def are_equal_types(type1: DataType, type2: DataType) -> bool: @@ -644,104 +844,6 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool: return type1 == type2 -def compute_binary_op_between_snowpark_column_and_scalar( - op: str, - first_operand: SnowparkColumn, - datatype: DataTypeGetter, - second_operand: Scalar, -) -> SnowparkPandasColumn: - """ - Compute the binary operation between a Snowpark column and a scalar. - Args: - op: the name of binary operation - first_operand: The SnowparkColumn for lhs - datatype: Callable for Snowpark data type - second_operand: Scalar value - - Returns: - SnowparkPandasColumn for translated pandas operation - """ - - def second_datatype() -> DataType: - return infer_object_type(second_operand) - - return compute_binary_op_between_snowpark_columns( - op, first_operand, datatype, pandas_lit(second_operand), second_datatype - ) - - -def compute_binary_op_between_scalar_and_snowpark_column( - op: str, - first_operand: Scalar, - second_operand: SnowparkColumn, - datatype: DataTypeGetter, -) -> SnowparkPandasColumn: - """ - Compute the binary operation between a scalar and a Snowpark column. - Args: - op: the name of binary operation - first_operand: Scalar value - second_operand: The SnowparkColumn for rhs - datatype: Callable for Snowpark data type - it is not needed. - - Returns: - SnowparkPandasColumn for translated pandas operation - """ - - def first_datatype() -> DataType: - return infer_object_type(first_operand) - - return compute_binary_op_between_snowpark_columns( - op, pandas_lit(first_operand), first_datatype, second_operand, datatype - ) - - -def compute_binary_op_with_fill_value( - op: str, - lhs: SnowparkColumn, - lhs_datatype: DataTypeGetter, - rhs: SnowparkColumn, - rhs_datatype: DataTypeGetter, - fill_value: Scalar, -) -> SnowparkPandasColumn: - """ - Helper method for performing binary operations. - 1. Fills NaN/None values in the lhs and rhs with the given fill_value. - 2. Computes the binary operation expression for lhs rhs. - - fill_value replaces NaN/None values when only either lhs or rhs is NaN/None, not both lhs and rhs. - For instance, with fill_value = 100, - 1. Given lhs = None and rhs = 10, lhs is replaced with fill_value. - result = lhs + rhs => None + 10 => 100 (replaced) + 10 = 110 - 2. Given lhs = 3 and rhs = None, rhs is replaced with fill_value. - result = lhs + rhs => 3 + None => 3 + 100 (replaced) = 103 - 3. Given lhs = None and rhs = None, neither lhs nor rhs is replaced since they both are None. - result = lhs + rhs => None + None => None. - - Args: - op: pandas operation to perform between lhs and rhs - lhs: the lhs SnowparkColumn - lhs_datatype: Callable for Snowpark Datatype for lhs - rhs: the rhs SnowparkColumn - rhs_datatype: Callable for Snowpark Datatype for rhs - fill_value: Fill existing missing (NaN) values, and any new element needed for - successful DataFrame alignment, with this value before computation. 
- - Returns: - SnowparkPandasColumn for translated pandas operation - """ - lhs_cond, rhs_cond = lhs, rhs - if fill_value is not None: - fill_value_lit = pandas_lit(fill_value) - lhs_cond = iff(lhs.is_null() & ~rhs.is_null(), fill_value_lit, lhs) - rhs_cond = iff(rhs.is_null() & ~lhs.is_null(), fill_value_lit, rhs) - - return compute_binary_op_between_snowpark_columns( - op, lhs_cond, lhs_datatype, rhs_cond, rhs_datatype - ) - - def merge_label_and_identifier_pairs( sorted_column_labels: list[str], q_frame_sorted: list[tuple[str, str]], diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index f5c6be3b751..b1a2736d120 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -185,11 +185,7 @@ sort_apply_udtf_result_columns_by_pandas_positions, ) from snowflake.snowpark.modin.plugin._internal.binary_op_utils import ( - compute_binary_op_between_scalar_and_snowpark_column, - compute_binary_op_between_snowpark_column_and_scalar, - compute_binary_op_between_snowpark_columns, - compute_binary_op_with_fill_value, - is_binary_op_supported, + BinaryOp, merge_label_and_identifier_pairs, prepare_binop_pairs_between_dataframe_and_dataframe, ) @@ -1855,7 +1851,7 @@ def _binary_op_scalar_rhs( replace_mapping = {} data_column_snowpark_pandas_types = [] for identifier in self._modin_frame.data_column_snowflake_quoted_identifiers: - expression, snowpark_pandas_type = compute_binary_op_with_fill_value( + expression, snowpark_pandas_type = BinaryOp.create_with_fill_value( op=op, lhs=col(identifier), lhs_datatype=lambda identifier=identifier: self._modin_frame.get_snowflake_type( @@ -1864,7 +1860,7 @@ def _binary_op_scalar_rhs( rhs=pandas_lit(other), rhs_datatype=lambda: infer_object_type(other), fill_value=fill_value, - ) + ).compute() replace_mapping[identifier] = expression data_column_snowpark_pandas_types.append(snowpark_pandas_type) return SnowflakeQueryCompiler( @@ -1915,7 +1911,7 @@ def _binary_op_list_like_rhs_axis_0( replace_mapping = {} snowpark_pandas_types = [] for identifier in new_frame.data_column_snowflake_quoted_identifiers[:-1]: - expression, snowpark_pandas_type = compute_binary_op_with_fill_value( + expression, snowpark_pandas_type = BinaryOp.create_with_fill_value( op=op, lhs=col(identifier), lhs_datatype=lambda identifier=identifier: new_frame.get_snowflake_type( @@ -1924,7 +1920,7 @@ def _binary_op_list_like_rhs_axis_0( rhs=col(other_identifier), rhs_datatype=lambda: new_frame.get_snowflake_type(other_identifier), fill_value=fill_value, - ) + ).compute() replace_mapping[identifier] = expression snowpark_pandas_types.append(snowpark_pandas_type) @@ -1987,7 +1983,7 @@ def _binary_op_list_like_rhs_axis_1( # rhs is not guaranteed to be a scalar value - it can be a list-like as well. # Convert all list-like objects to a list. 
rhs_lit = pandas_lit(rhs) if is_scalar(rhs) else pandas_lit(rhs.tolist()) - expression, snowpark_pandas_type = compute_binary_op_with_fill_value( + expression, snowpark_pandas_type = BinaryOp.create_with_fill_value( op, lhs=lhs, lhs_datatype=lambda identifier=identifier: self._modin_frame.get_snowflake_type( @@ -1996,7 +1992,7 @@ def _binary_op_list_like_rhs_axis_1( rhs=rhs_lit, rhs_datatype=lambda rhs=rhs: infer_object_type(rhs), fill_value=fill_value, - ) + ).compute() replace_mapping[identifier] = expression snowpark_pandas_types.append(snowpark_pandas_type) @@ -2057,7 +2053,7 @@ def binary_op( # match pandas documentation; hence it is omitted in the Snowpark pandas implementation. raise ValueError("Only scalars can be used as fill_value.") - if not is_binary_op_supported(op): + if not BinaryOp.is_binary_op_supported(op): ErrorMessage.not_implemented( f"Snowpark pandas doesn't yet support '{op}' binary operation" ) @@ -2122,7 +2118,7 @@ def binary_op( )[0] # add new column with result as unnamed - new_column_expr, snowpark_pandas_type = compute_binary_op_with_fill_value( + new_column_expr, snowpark_pandas_type = BinaryOp.create_with_fill_value( op=op, lhs=col(lhs_quoted_identifier), lhs_datatype=lambda: aligned_frame.get_snowflake_type( @@ -2133,7 +2129,7 @@ def binary_op( rhs_quoted_identifier ), fill_value=fill_value, - ) + ).compute() # name is dropped when names of series differ. A dropped name is using unnamed series label. new_column_name = ( @@ -10767,7 +10763,7 @@ def _make_discrete_difference_expression( snowpark_pandas_type=None, ) else: - return compute_binary_op_between_snowpark_columns( + return BinaryOp.create( "sub", col(snowflake_quoted_identifier), lambda: column_datatype, @@ -10779,7 +10775,7 @@ def _make_discrete_difference_expression( ) ), lambda: column_datatype, - ) + ).compute() else: # periods is the number of columns to *go back*. @@ -10828,13 +10824,13 @@ def _make_discrete_difference_expression( col1 = cast(col1, IntegerType()) if isinstance(col2_dtype, BooleanType): col2 = cast(col2, IntegerType()) - return compute_binary_op_between_snowpark_columns( + return BinaryOp.create( "sub", col1, lambda: col1_dtype, col2, lambda: col2_dtype, - ) + ).compute() def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler": """ @@ -14432,7 +14428,7 @@ def _binary_op_between_dataframe_and_series_along_axis_0( ) ) - # Lazify type map here for calling compute_binary_op_between_snowpark_columns. + # Lazify type map here for calling binaryOp.compute. 
def create_lazy_type_functions( identifiers: list[str], ) -> list[DataTypeGetter]: @@ -14462,12 +14458,9 @@ def create_lazy_type_functions( replace_mapping = {} snowpark_pandas_types = [] for left, left_datatype in zip(left_result_data_identifiers, left_datatypes): - ( - expression, - snowpark_pandas_type, - ) = compute_binary_op_between_snowpark_columns( + (expression, snowpark_pandas_type,) = BinaryOp.create( op, col(left), left_datatype, col(right), right_datatype - ) + ).compute() snowpark_pandas_types.append(snowpark_pandas_type) replace_mapping[left] = expression update_result = joined_frame.result_frame.update_snowflake_quoted_identifiers_with_expressions( @@ -14726,14 +14719,14 @@ def infer_sorted_column_labels( replace_mapping = {} data_column_snowpark_pandas_types = [] for p in left_right_pairs: - result_expression, snowpark_pandas_type = compute_binary_op_with_fill_value( + result_expression, snowpark_pandas_type = BinaryOp.create_with_fill_value( op=op, lhs=p.lhs, lhs_datatype=p.lhs_datatype, rhs=p.rhs, rhs_datatype=p.rhs_datatype, fill_value=fill_value, - ) + ).compute() replace_mapping[p.identifier] = result_expression data_column_snowpark_pandas_types.append(snowpark_pandas_type) # Create restricted frame with only combined / replaced labels. @@ -15000,19 +14993,19 @@ def infer_sorted_column_labels( snowpark_pandas_labels = [] for label, identifier in overlapping_pairs: expression, new_type = ( - compute_binary_op_between_scalar_and_snowpark_column( + BinaryOp.create_with_lhs_scalar( op, series.loc[label], col(identifier), datatype_getters[identifier], - ) + ).compute() if squeeze_self - else compute_binary_op_between_snowpark_column_and_scalar( + else BinaryOp.create_with_rhs_scalar( op, col(identifier), datatype_getters[identifier], series.loc[label], - ) + ).compute() ) snowpark_pandas_labels.append(new_type) replace_mapping[identifier] = expression @@ -17465,9 +17458,11 @@ def equals( ) replace_mapping = { - p.identifier: compute_binary_op_between_snowpark_columns( + p.identifier: BinaryOp.create( "equal_null", p.lhs, p.lhs_datatype, p.rhs, p.rhs_datatype - ).snowpark_column + ) + .compute() + .snowpark_column for p in left_right_pairs } @@ -17995,7 +17990,7 @@ def compare( right_identifier = result_column_mapper.right_quoted_identifiers_map[ right_identifier ] - op_result = compute_binary_op_between_snowpark_columns( + op_result = BinaryOp.create( op="equal_null", first_operand=col(left_identifier), first_datatype=functools.partial( @@ -18005,7 +18000,7 @@ def compare( second_datatype=functools.partial( lambda col: result_frame.get_snowflake_type(col), right_identifier ), - ) + ).compute() binary_op_result = binary_op_result.append_column( str(left_pandas_label) + "_comparison_result", op_result.snowpark_column, @@ -18116,19 +18111,23 @@ def compare( right_identifier ] - cols_equal = compute_binary_op_between_snowpark_columns( - op="equal_null", - first_operand=col(left_mappped_identifier), - first_datatype=functools.partial( - lambda col: result_frame.get_snowflake_type(col), - left_mappped_identifier, - ), - second_operand=col(right_mapped_identifier), - second_datatype=functools.partial( - lambda col: result_frame.get_snowflake_type(col), - right_mapped_identifier, - ), - ).snowpark_column + cols_equal = ( + BinaryOp.create( + op="equal_null", + first_operand=col(left_mappped_identifier), + first_datatype=functools.partial( + lambda col: result_frame.get_snowflake_type(col), + left_mappped_identifier, + ), + second_operand=col(right_mapped_identifier), + 
second_datatype=functools.partial( + lambda col: result_frame.get_snowflake_type(col), + right_mapped_identifier, + ), + ) + .compute() + .snowpark_column + ) # Add a column containing the values from `self`, but replace # matching values with null. From 64c0e34e74a6f7e71b1c31b3555c1c7eecc3ed55 Mon Sep 17 00:00:00 2001 From: Andong Zhan Date: Wed, 11 Sep 2024 13:26:17 -0700 Subject: [PATCH 2/5] SNOW-1657456 Improved to persist the original timezone offset for TIMESTAMP_TZ type (#2260) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1657456 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. Please write a short description of how your code change solves the related issue. The problem was that to_pandas use Snowflake Python connector to read query result into a pandas dataframe. However, the connector will convert TIMESTAMP_TZ result to timestamps using the local session timezone which does not match with the pandas behavior in our scenario. The goal is to preserve the original timezone offsets for timestamp_tz. The idea here is to convert timestamp_tz to string before calling to_pandas and then cast them back to datetime64tz on the native pandas result. --- CHANGELOG.md | 5 + .../snowpark/modin/pandas/general.py | 14 +- .../snowpark/modin/plugin/_internal/utils.py | 79 +++++++- .../modin/plugin/docstrings/series.py | 2 +- .../modin/plugin/extensions/datetime_index.py | 2 +- tests/integ/modin/frame/test_dtypes.py | 4 +- tests/integ/modin/index/conftest.py | 8 +- tests/integ/modin/series/test_astype.py | 32 ++-- tests/integ/modin/test_dtype_mapping.py | 10 +- .../integ/modin/test_from_pandas_to_pandas.py | 177 +++++++----------- tests/integ/modin/tools/test_to_datetime.py | 2 +- 11 files changed, 176 insertions(+), 159 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b119e75573a..92b240a7d6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## 1.23.0 (TBD) +### Snowpark pandas API Updates + +#### Improvements + +- Improved `to_pandas` to persist the original timezone offset for TIMESTAMP_TZ type. ## 1.22.0 (2024-09-10) diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py index d5d158373de..2ca9d8e5b83 100644 --- a/src/snowflake/snowpark/modin/pandas/general.py +++ b/src/snowflake/snowpark/modin/pandas/general.py @@ -1047,7 +1047,7 @@ def unique(values) -> np.ndarray: >>> pd.unique([pd.Timestamp('2016-01-01', tz='US/Eastern') ... 
for _ in range(3)]) - array([Timestamp('2015-12-31 21:00:00-0800', tz='America/Los_Angeles')], + array([Timestamp('2016-01-01 00:00:00-0500', tz='UTC-05:00')], dtype=object) >>> pd.unique([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")]) @@ -1750,35 +1750,35 @@ def to_datetime( DatetimeIndex(['2018-10-26 12:00:00', '2018-10-26 13:00:15'], dtype='datetime64[ns]', freq=None) >>> pd.to_datetime(['2018-10-26 12:00:00 -0500', '2018-10-26 13:00:00 -0500']) - DatetimeIndex(['2018-10-26 10:00:00-07:00', '2018-10-26 11:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2018-10-26 12:00:00-05:00', '2018-10-26 13:00:00-05:00'], dtype='datetime64[ns, UTC-05:00]', freq=None) - Use right format to convert to timezone-aware type (Note that when call Snowpark pandas API to_pandas() the timezone-aware output will always be converted to session timezone): >>> pd.to_datetime(['2018-10-26 12:00:00 -0500', '2018-10-26 13:00:00 -0500'], format="%Y-%m-%d %H:%M:%S %z") - DatetimeIndex(['2018-10-26 10:00:00-07:00', '2018-10-26 11:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2018-10-26 12:00:00-05:00', '2018-10-26 13:00:00-05:00'], dtype='datetime64[ns, UTC-05:00]', freq=None) - Timezone-aware inputs *with mixed time offsets* (for example issued from a timezone with daylight savings, such as Europe/Paris): >>> pd.to_datetime(['2020-10-25 02:00:00 +0200', '2020-10-25 04:00:00 +0100']) - DatetimeIndex(['2020-10-24 17:00:00-07:00', '2020-10-24 20:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex([2020-10-25 02:00:00+02:00, 2020-10-25 04:00:00+01:00], dtype='object', freq=None) >>> pd.to_datetime(['2020-10-25 02:00:00 +0200', '2020-10-25 04:00:00 +0100'], format="%Y-%m-%d %H:%M:%S %z") - DatetimeIndex(['2020-10-24 17:00:00-07:00', '2020-10-24 20:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex([2020-10-25 02:00:00+02:00, 2020-10-25 04:00:00+01:00], dtype='object', freq=None) Setting ``utc=True`` makes sure always convert to timezone-aware outputs: - Timezone-naive inputs are *localized* based on the session timezone >>> pd.to_datetime(['2018-10-26 12:00', '2018-10-26 13:00'], utc=True) - DatetimeIndex(['2018-10-26 05:00:00-07:00', '2018-10-26 06:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2018-10-26 12:00:00+00:00', '2018-10-26 13:00:00+00:00'], dtype='datetime64[ns, UTC]', freq=None) - Timezone-aware inputs are *converted* to session timezone >>> pd.to_datetime(['2018-10-26 12:00:00 -0530', '2018-10-26 12:00:00 -0500'], ... 
utc=True) - DatetimeIndex(['2018-10-26 10:30:00-07:00', '2018-10-26 10:00:00-07:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2018-10-26 17:30:00+00:00', '2018-10-26 17:00:00+00:00'], dtype='datetime64[ns, UTC]', freq=None) """ # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py raise_if_native_pandas_objects(arg) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/utils.py b/src/snowflake/snowpark/modin/plugin/_internal/utils.py index 5656bbfb14a..9f01954ab2c 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/utils.py @@ -41,6 +41,7 @@ mean, min as min_, sum as sum_, + to_char, to_timestamp_ntz, to_timestamp_tz, typeof, @@ -75,6 +76,8 @@ StringType, StructField, StructType, + TimestampTimeZone, + TimestampType, VariantType, _FractionalType, ) @@ -1289,14 +1292,23 @@ def snowpark_to_pandas_helper( ) -> Union[native_pd.Index, native_pd.DataFrame]: """ The helper function retrieves a pandas dataframe from an OrderedDataFrame. Performs necessary type - conversions for variant types on the client. This function issues 2 queries, one metadata query - to retrieve the schema and one query to retrieve the data values. + conversions including + 1. For VARIANT types, OrderedDataFrame.to_pandas may convert datetime like types to string. So we add one `typeof` + column for each variant column and use that metadata to convert datetime like types back to their original types. + 2. For TIMESTAMP_TZ type, OrderedDataFrame.to_pandas will convert them into the local session timezone and lose the + original timezone. So we cast TIMESTAMP_TZ columns to string first and then convert them back after to_pandas to + preserve the original timezone. Note that the actual timezone will be lost in Snowflake backend but only the offset + preserved. + 3. For Timedelta columns, since currently we represent the values using integers, here we need to explicitly cast + them back to Timedelta. Args: frame: The internal frame to convert to pandas Dataframe (or Index if index_only is true) index_only: if true, only turn the index columns into a pandas Index - statement_params: Dictionary of statement level parameters to be passed to conversion function of ordered dataframe abstraction. - kwargs: Additional keyword-only args to pass to internal `to_pandas` conversion for orderded dataframe abstraction. + statement_params: Dictionary of statement level parameters to be passed to conversion function of ordered + dataframe abstraction. + kwargs: Additional keyword-only args to pass to internal `to_pandas` conversion for ordered dataframe + abstraction. Returns: pandas dataframe @@ -1365,7 +1377,7 @@ def snowpark_to_pandas_helper( ) variant_type_identifiers = list(map(lambda t: t[0], variant_type_columns_info)) - # Step 3: Create for each variant type column a separate type column (append at end), and retrieve data values + # Step 3.1: Create for each variant type column a separate type column (append at end), and retrieve data values # (and types for variant type columns). 
variant_type_typeof_identifiers = ( ordered_dataframe.generate_snowflake_quoted_identifiers( @@ -1384,10 +1396,36 @@ def snowpark_to_pandas_helper( [typeof(col(id)) for id in variant_type_identifiers], ) + # Step 3.2: cast timestamp_tz to string to preserve their original timezone offsets + timestamp_tz_identifiers = [ + info[0] + for info in columns_info + if info[1] == TimestampType(TimestampTimeZone.TZ) + ] + timestamp_tz_str_identifiers = ( + ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=[ + f"{unquote_name_if_quoted(id)}_str" for id in timestamp_tz_identifiers + ], + excluded=column_identifiers, + ) + ) + if len(timestamp_tz_identifiers): + ordered_dataframe = append_columns( + ordered_dataframe, + timestamp_tz_str_identifiers, + [ + to_char(col(id), format="YYYY-MM-DD HH24:MI:SS.FF9 TZHTZM") + for id in timestamp_tz_identifiers + ], + ) + # ensure that snowpark_df has unique identifiers, so the native pandas DataFrame object created here # also does have unique column names which is a prerequisite for the post-processing logic following. assert is_duplicate_free( - column_identifiers + variant_type_typeof_identifiers + column_identifiers + + variant_type_typeof_identifiers + + timestamp_tz_str_identifiers ), "Snowpark DataFrame to convert must have unique column identifiers" pandas_df = ordered_dataframe.to_pandas(statement_params=statement_params, **kwargs) @@ -1400,7 +1438,9 @@ def snowpark_to_pandas_helper( # Step 3a: post-process variant type columns, if any exist. id_to_label_mapping = dict( zip( - column_identifiers + variant_type_typeof_identifiers, + column_identifiers + + variant_type_typeof_identifiers + + timestamp_tz_str_identifiers, pandas_df.columns, ) ) @@ -1439,6 +1479,25 @@ def convert_variant_type_to_pandas(row: native_pd.Series) -> Any: id_to_label_mapping[quoted_name] ].apply(lambda value: None if value is None else json.loads(value)) + # Convert timestamp_tz in string back to datetime64tz. + if any( + dtype == TimestampType(TimestampTimeZone.TZ) for (_, dtype) in columns_info + ): + id_to_label_mapping = dict( + zip( + column_identifiers + + variant_type_typeof_identifiers + + timestamp_tz_str_identifiers, + pandas_df.columns, + ) + ) + for ts_id, ts_str_id in zip( + timestamp_tz_identifiers, timestamp_tz_str_identifiers + ): + pandas_df[id_to_label_mapping[ts_id]] = native_pd.to_datetime( + pandas_df[id_to_label_mapping[ts_str_id]] + ) + # Step 5. Return the original amount of columns by stripping any typeof(...) columns appended if # schema contained VariantType. downcast_pandas_df = pandas_df[pandas_df.columns[: len(columns_info)]] @@ -1493,7 +1552,11 @@ def convert_str_to_timedelta(x: str) -> pd.Timedelta: # multiple timezones. So here we cast the index to the index_type when ret = pd.Index(...) above cannot # figure out a non-object dtype. Note that the index_type is a logical type may not be 100% accurate. 
if is_object_dtype(ret.dtype) and not is_object_dtype(index_type): - ret = ret.astype(index_type) + # TODO: SNOW-1657460 fix index_type for timestamp_tz + try: + ret = ret.astype(index_type) + except ValueError: # e.g., Tz-aware datetime.datetime cannot be converted to datetime64 + pass return ret # to_pandas() does not preserve the index information and will just return a diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series.py b/src/snowflake/snowpark/modin/plugin/docstrings/series.py index 1d351fd67af..9e4ebd4d257 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series.py @@ -3428,7 +3428,7 @@ def unique(): >>> pd.Series([pd.Timestamp('2016-01-01', tz='US/Eastern') ... for _ in range(3)]).unique() - array([Timestamp('2015-12-31 21:00:00-0800', tz='America/Los_Angeles')], + array([Timestamp('2016-01-01 00:00:00-0500', tz='UTC-05:00')], dtype=object) """ diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index 7be7adb54c1..df136af1a34 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -229,7 +229,7 @@ def __init__( -------- >>> idx = pd.DatetimeIndex(["1/1/2020 10:00:00+00:00", "2/1/2020 11:00:00+00:00"], tz="America/Los_Angeles") >>> idx - DatetimeIndex(['2020-01-01 02:00:00-08:00', '2020-02-01 03:00:00-08:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None) + DatetimeIndex(['2020-01-01 02:00:00-08:00', '2020-02-01 03:00:00-08:00'], dtype='datetime64[ns, UTC-08:00]', freq=None) """ # DatetimeIndex is already initialized in __new__ method. We keep this method # only for docstring generation. 
diff --git a/tests/integ/modin/frame/test_dtypes.py b/tests/integ/modin/frame/test_dtypes.py index c3773bdd6de..b078b31f6c5 100644 --- a/tests/integ/modin/frame/test_dtypes.py +++ b/tests/integ/modin/frame/test_dtypes.py @@ -351,7 +351,7 @@ def test_insert_multiindex_multi_label(label1, label2): native_pd.Timestamp(1513393355, unit="s", tz="US/Pacific"), ], "datetime64[ns, America/Los_Angeles]", - "datetime64[ns, America/Los_Angeles]", + "datetime64[ns, UTC-08:00]", "datetime64[ns]", ), ( @@ -372,7 +372,7 @@ def test_insert_multiindex_multi_label(label1, label2): native_pd.Timestamp(1513393355, unit="s", tz="US/Pacific"), ], "object", - "datetime64[ns, America/Los_Angeles]", + "datetime64[ns, UTC-08:00]", "datetime64[ns]", ), ], diff --git a/tests/integ/modin/index/conftest.py b/tests/integ/modin/index/conftest.py index 3c6362dd83c..84454fc4a27 100644 --- a/tests/integ/modin/index/conftest.py +++ b/tests/integ/modin/index/conftest.py @@ -33,11 +33,11 @@ native_pd.Index(["a", "b", "c", "d"]), native_pd.DatetimeIndex( ["2020-01-01 10:00:00+00:00", "2020-02-01 11:00:00+00:00"], - tz="America/Los_Angeles", + tz="UTC-08:00", ), native_pd.DatetimeIndex( ["2020-01-01 10:00:00+05:00", "2020-02-01 11:00:00+05:00"], - tz="America/Los_Angeles", + tz="UTC", ), native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]), native_pd.TimedeltaIndex(["0 days", "1 days", "3 days"]), @@ -55,11 +55,11 @@ native_pd.Index(["a", "b", 1, 2, None, "a", 2], name="mixed index"), native_pd.DatetimeIndex( ["2020-01-01 10:00:00+00:00", "2020-02-01 11:00:00+00:00"], - tz="America/Los_Angeles", + tz="UTC", ), native_pd.DatetimeIndex( ["2020-01-01 10:00:00+00:00", "2020-01-01 10:00:00+00:00"], - tz="America/Los_Angeles", + tz="UTC-08:00", ), ] diff --git a/tests/integ/modin/series/test_astype.py b/tests/integ/modin/series/test_astype.py index 5bbce79b01b..030416d65c5 100644 --- a/tests/integ/modin/series/test_astype.py +++ b/tests/integ/modin/series/test_astype.py @@ -173,6 +173,11 @@ def test_astype_basic(from_dtype, to_dtype): ) def test_astype_to_DatetimeTZDtype(from_dtype, to_tz): to_dtype = f"datetime64[ns, {to_tz}]" + offset_map = { + "UTC": "UTC", + "Asia/Tokyo": "UTC+09:00", + "America/Los_Angeles": "UTC-08:00", + } seed = ( [True, False, False, True] # if isinstance(from_dtype, BooleanDtype) @@ -189,23 +194,22 @@ def test_astype_to_DatetimeTZDtype(from_dtype, to_tz): native_pd.Series(seed, dtype=from_dtype).astype(to_dtype) elif isinstance(from_dtype, StringDtype) or from_dtype is str: # Snowpark pandas use Snowflake auto format detection and the behavior can be different from native pandas - # to_pandas always convert timezone to the local timezone today, i.e., "America/Los_angeles" with SqlCounter(query_count=1): assert_snowpark_pandas_equal_to_pandas( pd.Series(seed, dtype=from_dtype).astype(to_dtype), native_pd.Series( [ native_pd.Timestamp("1970-01-01 00:00:00", tz="UTC").tz_convert( - "America/Los_Angeles" + offset_map[to_tz] ), native_pd.Timestamp("1970-01-01 00:00:01", tz="UTC").tz_convert( - "America/Los_Angeles" + offset_map[to_tz] ), native_pd.Timestamp("1970-01-01 00:00:02", tz="UTC").tz_convert( - "America/Los_Angeles" + offset_map[to_tz] ), native_pd.Timestamp("1970-01-01 00:00:03", tz="UTC").tz_convert( - "America/Los_Angeles" + offset_map[to_tz] ), ] ), @@ -251,15 +255,15 @@ def test_astype_to_DatetimeTZDtype(from_dtype, to_tz): ): native_pd.Series(seed, dtype=from_dtype).astype(to_dtype) expected_to_pandas = ( - native_pd.Series(seed, dtype=from_dtype).dt.tz_localize("UTC") - # Snowpark 
pandas to_pandas() will convert timestamp_tz to default local timezone - .dt.tz_convert("America/Los_Angeles") + native_pd.Series(seed, dtype=from_dtype) + .dt.tz_localize("UTC") + .dt.tz_convert(offset_map[to_tz]) ) else: expected_to_pandas = ( - native_pd.Series(seed, dtype=from_dtype).astype(to_dtype) - # Snowpark pandas to_pandas() will convert timestamp_tz to default local timezone - .dt.tz_convert("America/Los_Angeles") + native_pd.Series(seed, dtype=from_dtype) + .astype(to_dtype) + .dt.tz_convert(offset_map[to_tz]) ) assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( s, @@ -392,11 +396,7 @@ def test_python_datetime_astype_DatetimeTZDtype(seed): with SqlCounter(query_count=1): snow = s.astype(to_dtype) assert snow.dtype == np.dtype(" from_pandas => TIMESTAMP_TZ(any_tz) => to_pandas => DatetimeTZDtype(session_tz) - # - # Note that python connector will convert any TIMESTAMP_TZ to DatetimeTZDtype with the current session/statement - # timezone, e.g., 1969-12-31 19:00:00 -0500 will be converted to 1970-00-01 00:00:00 in UTC if the session/statement - # parameter TIMEZONE = 'UTC' - # TODO: SNOW-871210 no need session parameter change once the bug is fixed - try: - session.sql(f"alter session set timezone = '{timezone}'").collect() - - def get_series_with_tz(tz): - return ( - native_pd.Series([1] * 3) - .astype("int64") - .astype(f"datetime64[ns, {tz}]") - ) +@sql_count_checker(query_count=1) +def test_from_to_pandas_datetime64_timezone_support(): + def get_series_with_tz(tz): + return native_pd.Series([1] * 3).astype("int64").astype(f"datetime64[ns, {tz}]") - # same timestamps representing in different time zone - test_data_columns = { - "utc": get_series_with_tz("UTC"), - "pacific": get_series_with_tz("US/Pacific"), - "tokyo": get_series_with_tz("Asia/Tokyo"), - } + # same timestamps representing in different time zone + test_data_columns = { + "utc": get_series_with_tz("UTC"), + "pacific": get_series_with_tz("US/Pacific"), + "tokyo": get_series_with_tz("Asia/Tokyo"), + } - # expected to_pandas dataframe's timezone is controlled by session/statement parameter TIMEZONE - expected_to_pandas = native_pd.DataFrame( - { - series: test_data_columns[series].dt.tz_convert(timezone) - for series in test_data_columns - } - ) - assert_snowpark_pandas_equal_to_pandas( - pd.DataFrame(test_data_columns), - expected_to_pandas, - # configure different timezones to to_pandas and verify the timestamps are converted correctly - statement_params={"timezone": timezone}, - ) - finally: - # TODO: SNOW-871210 no need session parameter change once the bug is fixed - session.sql("alter session unset timezone").collect() + expected_data_columns = { + "utc": get_series_with_tz("UTC"), + "pacific": get_series_with_tz("UTC-08:00"), + "tokyo": get_series_with_tz("UTC+09:00"), + } + # expected to_pandas dataframe's timezone is controlled by session/statement parameter TIMEZONE + expected_to_pandas = native_pd.DataFrame(expected_data_columns) + assert_snowpark_pandas_equal_to_pandas( + pd.DataFrame(test_data_columns), + expected_to_pandas, + ) -@pytest.mark.parametrize("timezone", ["UTC", "US/Pacific", "US/Eastern"]) -@sql_count_checker(query_count=3) -def test_from_to_pandas_datetime64_multi_timezone_current_behavior(session, timezone): - try: - # TODO: SNOW-871210 no need session parameter change once the bug is fixed - session.sql(f"alter session set timezone = '{timezone}'").collect() - - # This test also verifies the current behaviors of to_pandas() for datetime with no tz, same tz, or multi tz: - # no 
tz => TIMESTAMP_NTZ - # same tz => TIMESTAMP_TZ - # multi tz => TIMESTAMP_NTZ - multi_tz_data = ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-07:00"] - test_data_columns = { - "no tz": native_pd.to_datetime( - native_pd.Series(["2019-05-21 12:00:00", "2019-05-21 12:15:00"]) - ), # dtype = datetime64[ns] - "same tz": native_pd.to_datetime( - native_pd.Series( - ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-06:00"] - ) - ), # dtype = datetime64[ns, tz] - "multi tz": native_pd.to_datetime( - native_pd.Series(multi_tz_data) - ), # dtype = object and value type is Python datetime - } +@sql_count_checker(query_count=1) +def test_from_to_pandas_datetime64_multi_timezone_current_behavior(): + # This test also verifies the current behaviors of to_pandas() for datetime with no tz, same tz, or multi tz: + # no tz => TIMESTAMP_NTZ + # same tz => TIMESTAMP_TZ + # multi tz => TIMESTAMP_TZ + multi_tz_data = ["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-07:00"] + test_data_columns = { + "no tz": native_pd.to_datetime( + native_pd.Series(["2019-05-21 12:00:00", "2019-05-21 12:15:00"]) + ), # dtype = datetime64[ns] + "same tz": native_pd.to_datetime( + native_pd.Series(["2019-05-21 12:00:00-06:00", "2019-05-21 12:15:00-06:00"]) + ), # dtype = datetime64[ns, tz] + "multi tz": native_pd.to_datetime( + native_pd.Series(multi_tz_data) + ), # dtype = object and value type is Python datetime + } - expected_to_pandas = native_pd.DataFrame( - { - "no tz": test_data_columns["no tz"], # dtype = datetime64[ns] - "same tz": test_data_columns["same tz"].dt.tz_convert( - timezone - ), # dtype = datetime64[ns, tz] - "multi tz": native_pd.Series( - [ - native_pd.to_datetime(t).tz_convert(timezone) - for t in multi_tz_data - ] - ), - } - ) + expected_to_pandas = native_pd.DataFrame(test_data_columns) - test_df = native_pd.DataFrame(test_data_columns) - # dtype checks for each series - no_tz_dtype = test_df.dtypes["no tz"] - assert is_datetime64_any_dtype(no_tz_dtype) and not isinstance( - no_tz_dtype, DatetimeTZDtype - ) - same_tz_dtype = test_df.dtypes["same tz"] - assert is_datetime64_any_dtype(same_tz_dtype) and isinstance( - same_tz_dtype, DatetimeTZDtype - ) - multi_tz_dtype = test_df.dtypes["multi tz"] - assert ( - not is_datetime64_any_dtype(multi_tz_dtype) - and not isinstance(multi_tz_dtype, DatetimeTZDtype) - and str(multi_tz_dtype) == "object" - ) - # sample value - assert isinstance(test_df["multi tz"][0], datetime.datetime) - assert test_df["multi tz"][0].tzinfo is not None - assert_snowpark_pandas_equal_to_pandas( - pd.DataFrame(test_df), - expected_to_pandas, - statement_params={"timezone": timezone}, - ) - finally: - # TODO: SNOW-871210 no need session parameter change once the bug is fixed - session.sql("alter session unset timezone").collect() + test_df = native_pd.DataFrame(test_data_columns) + # dtype checks for each series + no_tz_dtype = test_df.dtypes["no tz"] + assert is_datetime64_any_dtype(no_tz_dtype) and not isinstance( + no_tz_dtype, DatetimeTZDtype + ) + same_tz_dtype = test_df.dtypes["same tz"] + assert is_datetime64_any_dtype(same_tz_dtype) and isinstance( + same_tz_dtype, DatetimeTZDtype + ) + multi_tz_dtype = test_df.dtypes["multi tz"] + assert ( + not is_datetime64_any_dtype(multi_tz_dtype) + and not isinstance(multi_tz_dtype, DatetimeTZDtype) + and str(multi_tz_dtype) == "object" + ) + # sample value + assert isinstance(test_df["multi tz"][0], datetime.datetime) + assert test_df["multi tz"][0].tzinfo is not None + assert_snowpark_pandas_equal_to_pandas( + 
pd.DataFrame(test_df), + expected_to_pandas, + ) @sql_count_checker(query_count=1) diff --git a/tests/integ/modin/tools/test_to_datetime.py b/tests/integ/modin/tools/test_to_datetime.py index 1ea3445d15a..df11e6afb80 100644 --- a/tests/integ/modin/tools/test_to_datetime.py +++ b/tests/integ/modin/tools/test_to_datetime.py @@ -565,7 +565,7 @@ def test_to_datetime_mixed_datetime_and_string(self): assert_index_equal(res, expected) # Set utc=True to make sure timezone aware in to_datetime res = to_datetime(pd.Index(["2020-01-01 17:00:00 -0100", d2]), utc=True) - expected = pd.DatetimeIndex([d1, d2]) + expected = pd.DatetimeIndex([d1, d2], tz="UTC") assert_index_equal(res, expected) @pytest.mark.parametrize( From 62861f81197d8b69ed0cb5ea0408ae73edf86ce9 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 13:32:09 -0700 Subject: [PATCH 3/5] SNOW-1624166 add telemetry for compilation stage (#2250) --- .../_internal/compiler/plan_compiler.py | 66 ++++++++++++++++++- .../_internal/compiler/telemetry_constants.py | 15 +++++ src/snowflake/snowpark/_internal/telemetry.py | 23 +++++++ tests/integ/test_telemetry.py | 57 ++++++++++++++++ 4 files changed, 158 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index bef53f0f389..211b66820ec 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -3,8 +3,12 @@ # import copy +import time from typing import Dict, List +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + get_complexity_score, +) from snowflake.snowpark._internal.analyzer.snowflake_plan import ( PlanQueryType, Query, @@ -12,12 +16,18 @@ ) from snowflake.snowpark._internal.analyzer.snowflake_plan_node import LogicalPlan from snowflake.snowpark._internal.compiler.large_query_breakdown import ( + COMPLEXITY_SCORE_LOWER_BOUND, + COMPLEXITY_SCORE_UPPER_BOUND, LargeQueryBreakdown, ) from snowflake.snowpark._internal.compiler.repeated_subquery_elimination import ( RepeatedSubqueryElimination, ) +from snowflake.snowpark._internal.compiler.telemetry_constants import ( + CompilationStageTelemetryField, +) from snowflake.snowpark._internal.compiler.utils import create_query_generator +from snowflake.snowpark._internal.telemetry import TelemetryField from snowflake.snowpark.mock._connection import MockServerConnection @@ -68,24 +78,74 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]: if self.should_start_query_compilation(): # preparation for compilation # 1. make a copy of the original plan + start_time = time.time() + complexity_score_before_compilation = get_complexity_score( + self._plan.cumulative_node_complexity + ) logical_plans: List[LogicalPlan] = [copy.deepcopy(self._plan)] + deep_copy_end_time = time.time() + # 2. create a code generator with the original plan query_generator = create_query_generator(self._plan) - # apply each optimizations if needed + # 3. 
apply each optimizations if needed + # CTE optimization + cte_start_time = time.time() if self._plan.session.cte_optimization_enabled: repeated_subquery_eliminator = RepeatedSubqueryElimination( logical_plans, query_generator ) logical_plans = repeated_subquery_eliminator.apply() + + cte_end_time = time.time() + complexity_scores_after_cte = [ + get_complexity_score(logical_plan.cumulative_node_complexity) + for logical_plan in logical_plans + ] + + # Large query breakdown if self._plan.session.large_query_breakdown_enabled: large_query_breakdown = LargeQueryBreakdown( self._plan.session, query_generator, logical_plans ) logical_plans = large_query_breakdown.apply() - # do a final pass of code generation - return query_generator.generate_queries(logical_plans) + large_query_breakdown_end_time = time.time() + complexity_scores_after_large_query_breakdown = [ + get_complexity_score(logical_plan.cumulative_node_complexity) + for logical_plan in logical_plans + ] + + # 4. do a final pass of code generation + queries = query_generator.generate_queries(logical_plans) + + # log telemetry data + deep_copy_time = deep_copy_end_time - start_time + cte_time = cte_end_time - cte_start_time + large_query_breakdown_time = large_query_breakdown_end_time - cte_end_time + total_time = time.time() - start_time + session = self._plan.session + summary_value = { + TelemetryField.CTE_OPTIMIZATION_ENABLED.value: session.cte_optimization_enabled, + TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: session.large_query_breakdown_enabled, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: ( + COMPLEXITY_SCORE_LOWER_BOUND, + COMPLEXITY_SCORE_UPPER_BOUND, + ), + CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time, + CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time, + CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time, + CompilationStageTelemetryField.TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN.value: large_query_breakdown_time, + CompilationStageTelemetryField.COMPLEXITY_SCORE_BEFORE_COMPILATION.value: complexity_score_before_compilation, + CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION.value: complexity_scores_after_cte, + CompilationStageTelemetryField.COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN.value: complexity_scores_after_large_query_breakdown, + } + session._conn._telemetry_client.send_query_compilation_summary_telemetry( + session_id=session.session_id, + plan_uuid=self._plan.uuid, + compilation_stage_summary=summary_value, + ) + return queries else: final_plan = self._plan if self._plan.session.cte_optimization_enabled: diff --git a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py index 3c1f0d4fc5d..223b6a1326f 100644 --- a/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py +++ b/src/snowflake/snowpark/_internal/compiler/telemetry_constants.py @@ -6,10 +6,25 @@ class CompilationStageTelemetryField(Enum): + # types TYPE_LARGE_QUERY_BREAKDOWN_OPTIMIZATION_SKIPPED = ( "snowpark_large_query_breakdown_optimization_skipped" ) + TYPE_COMPILATION_STAGE_STATISTICS = "snowpark_compilation_stage_statistics" + + # keys KEY_REASON = "reason" + PLAN_UUID = "plan_uuid" + TIME_TAKEN_FOR_COMPILATION = "time_taken_for_compilation_sec" + TIME_TAKEN_FOR_DEEP_COPY_PLAN = "time_taken_for_deep_copy_plan_sec" + TIME_TAKEN_FOR_CTE_OPTIMIZATION = "time_taken_for_cte_optimization_sec" + TIME_TAKEN_FOR_LARGE_QUERY_BREAKDOWN 
= "time_taken_for_large_query_breakdown_sec" + COMPLEXITY_SCORE_BOUNDS = "complexity_score_bounds" + COMPLEXITY_SCORE_BEFORE_COMPILATION = "complexity_score_before_compilation" + COMPLEXITY_SCORE_AFTER_CTE_OPTIMIZATION = "complexity_score_after_cte_optimization" + COMPLEXITY_SCORE_AFTER_LARGE_QUERY_BREAKDOWN = ( + "complexity_score_after_large_query_breakdown" + ) class SkipLargeQueryBreakdownCategory(Enum): diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 05488398d16..8b9ef2acccb 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -168,6 +168,11 @@ def wrap(*args, **kwargs): ]._session.sql_simplifier_enabled try: api_calls[0][TelemetryField.QUERY_PLAN_HEIGHT.value] = plan.plan_height + # The uuid for df._select_statement can be different from df._plan. Since plan + # can take both values, we cannot use plan.uuid. We always use df._plan.uuid + # to track the queries. + uuid = args[0]._plan.uuid + api_calls[0][CompilationStageTelemetryField.PLAN_UUID.value] = uuid api_calls[0][ TelemetryField.QUERY_PLAN_NUM_DUPLICATE_NODES.value ] = plan.num_duplicate_nodes @@ -428,6 +433,24 @@ def send_large_query_breakdown_telemetry( } self.send(message) + def send_query_compilation_summary_telemetry( + self, + session_id: int, + plan_uuid: str, + compilation_stage_summary: Dict[str, Any], + ) -> None: + message = { + **self._create_basic_telemetry_data( + CompilationStageTelemetryField.TYPE_COMPILATION_STAGE_STATISTICS.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + CompilationStageTelemetryField.PLAN_UUID.value: plan_uuid, + **compilation_stage_summary, + }, + } + self.send(message) + def send_large_query_optimization_skipped_telemetry( self, session_id: int, reason: str ) -> None: diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index bcfa2cfa512..7aaa5c9e5dd 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -5,6 +5,7 @@ import decimal import sys +import uuid from functools import partial from typing import Any, Dict, Tuple @@ -599,6 +600,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -621,6 +623,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -643,6 +646,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -665,6 +669,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -687,6 +692,7 @@ def test_execute_queries_api_calls(session, sql_simplifier_enabled): { "name": "Session.range", "sql_simplifier_enabled": 
session.sql_simplifier_enabled, + "plan_uuid": df._plan.uuid, "query_plan_height": query_plan_height, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": { @@ -829,10 +835,15 @@ def test_dataframe_stat_functions_api_calls(session): column = 6 if session.sql_simplifier_enabled else 9 crosstab = df.stat.crosstab("empid", "month") + # uuid here is generated by an intermediate dataframe in crosstab implementation + # therefore we can't predict it. We check that the uuid for crosstab is same as + # that for df. + uuid = df._plan.api_calls[0]["plan_uuid"] assert crosstab._plan.api_calls == [ { "name": "Session.create_dataframe[values]", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": uuid, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": {"group_by": 1, "column": column, "literal": 48}, @@ -851,6 +862,7 @@ def test_dataframe_stat_functions_api_calls(session): { "name": "Session.create_dataframe[values]", "sql_simplifier_enabled": session.sql_simplifier_enabled, + "plan_uuid": uuid, "query_plan_height": 4, "query_plan_num_duplicate_nodes": 0, "query_plan_complexity": {"group_by": 1, "column": column, "literal": 48}, @@ -1166,3 +1178,48 @@ def send_large_query_optimization_skipped_telemetry(): ) assert data == expected_data assert type_ == "snowpark_large_query_breakdown_optimization_skipped" + + +def test_post_compilation_stage_telemetry(session): + client = session._conn._telemetry_client + uuid_str = str(uuid.uuid4()) + + def send_telemetry(): + summary_value = { + "cte_optimization_enabled": True, + "large_query_breakdown_enabled": True, + "complexity_score_bounds": (300, 600), + "time_taken_for_compilation": 0.136, + "time_taken_for_deep_copy_plan": 0.074, + "time_taken_for_cte_optimization": 0.01, + "time_taken_for_large_query_breakdown": 0.062, + "complexity_score_before_compilation": 1148, + "complexity_score_after_cte_optimization": [1148], + "complexity_score_after_large_query_breakdown": [514, 636], + } + client.send_query_compilation_summary_telemetry( + session_id=session.session_id, + plan_uuid=uuid_str, + compilation_stage_summary=summary_value, + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "plan_uuid": uuid_str, + "cte_optimization_enabled": True, + "large_query_breakdown_enabled": True, + "complexity_score_bounds": (300, 600), + "time_taken_for_compilation": 0.136, + "time_taken_for_deep_copy_plan": 0.074, + "time_taken_for_cte_optimization": 0.01, + "time_taken_for_large_query_breakdown": 0.062, + "complexity_score_before_compilation": 1148, + "complexity_score_after_cte_optimization": [1148], + "complexity_score_after_large_query_breakdown": [514, 636], + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_compilation_stage_statistics" From dce73a534bc00ab85459a5a0e892f7a535e14ed9 Mon Sep 17 00:00:00 2001 From: Naresh Kumar <113932371+sfc-gh-nkumar@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:04:25 -0700 Subject: [PATCH 4/5] SNOW-1654416: Add support for TimedeltaIndex.mean (#2267) SNOW-1654416: Add support for TimedeltaIndex.mean --------- Co-authored-by: Naren Krishna --- CHANGELOG.md | 4 ++ .../supported/timedelta_index_supported.rst | 2 +- .../plugin/extensions/timedelta_index.py | 44 +++++++++++++++---- .../index/test_timedelta_index_methods.py | 26 +++++++++++ 4 files changed, 67 insertions(+), 9 deletions(-) diff --git 
a/CHANGELOG.md b/CHANGELOG.md index 92b240a7d6c..c59d53ada05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ - Improved `to_pandas` to persist the original timezone offset for TIMESTAMP_TZ type. +#### New Features + +- Added support for `TimedeltaIndex.mean` method. + ## 1.22.0 (2024-09-10) ### Snowpark Python API Updates diff --git a/docs/source/modin/supported/timedelta_index_supported.rst b/docs/source/modin/supported/timedelta_index_supported.rst index 49dfcb305e4..f7a34c3552c 100644 --- a/docs/source/modin/supported/timedelta_index_supported.rst +++ b/docs/source/modin/supported/timedelta_index_supported.rst @@ -44,7 +44,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ | ``ceil`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ -| ``mean`` | N | | | +| ``mean`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ | ``total_seconds`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 96e2913f556..1dbb743aa32 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -32,7 +32,12 @@ from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable from pandas.core.dtypes.common import is_timedelta64_dtype +from snowflake.snowpark import functions as fn from snowflake.snowpark.modin.pandas import DataFrame, Series +from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + AggregateColumnOpParameters, + aggregate_with_ordered_dataframe, +) from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, ) @@ -40,6 +45,7 @@ from snowflake.snowpark.modin.plugin.utils.error_message import ( timedelta_index_not_implemented, ) +from snowflake.snowpark.types import LongType _CONSTRUCTOR_DEFAULTS = { "unit": lib.no_default, @@ -392,12 +398,11 @@ def to_pytimedelta(self) -> np.ndarray: datetime.timedelta(days=3)], dtype=object) """ - @timedelta_index_not_implemented() def mean( self, *, skipna: bool = True, axis: AxisInt | None = 0 - ) -> native_pd.Timestamp: + ) -> native_pd.Timedelta: """ - Return the mean value of the Array. + Return the mean value of the Timedelta values. Parameters ---------- @@ -407,17 +412,40 @@ def mean( Returns ------- - scalar Timestamp + scalar Timedelta + + Examples + -------- + >>> idx = pd.to_timedelta([1, 2, 3, 1], unit='D') + >>> idx + TimedeltaIndex(['1 days', '2 days', '3 days', '1 days'], dtype='timedelta64[ns]', freq=None) + >>> idx.mean() + Timedelta('1 days 18:00:00') See Also -------- numpy.ndarray.mean : Returns the average of array elements along a given axis. Series.mean : Return the mean value in a Series. - - Notes - ----- - mean is only defined for Datetime and Timedelta dtypes, not for Period. """ + if axis: + # Native pandas raises IndexError: tuple index out of range + # We raise a different more user-friendly error message. 
+ raise ValueError( + f"axis should be 0 for TimedeltaIndex.mean, found '{axis}'" + ) + # TODO SNOW-1620439: Reuse code from Series.mean. + frame = self._query_compiler._modin_frame + index_id = frame.index_column_snowflake_quoted_identifiers[0] + new_index_id = frame.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=["mean"] + )[0] + agg_column_op_params = AggregateColumnOpParameters( + index_id, LongType(), "mean", new_index_id, fn.mean, [] + ) + mean_value = aggregate_with_ordered_dataframe( + frame.ordered_dataframe, [agg_column_op_params], {"skipna": skipna} + ).collect()[0][0] + return native_pd.Timedelta(np.nan if mean_value is None else int(mean_value)) @timedelta_index_not_implemented() def as_unit(self, unit: str) -> TimedeltaIndex: diff --git a/tests/integ/modin/index/test_timedelta_index_methods.py b/tests/integ/modin/index/test_timedelta_index_methods.py index 25bef5364f2..c4d4a0b3a66 100644 --- a/tests/integ/modin/index/test_timedelta_index_methods.py +++ b/tests/integ/modin/index/test_timedelta_index_methods.py @@ -128,3 +128,29 @@ def test_timedelta_total_seconds(): native_index = native_pd.TimedeltaIndex(TIMEDELTA_INDEX_DATA) snow_index = pd.Index(native_index) eval_snowpark_pandas_result(snow_index, native_index, lambda x: x.total_seconds()) + + +@pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.parametrize("data", [[1, 2, 3], [1, 2, 3, None], [None], []]) +@sql_count_checker(query_count=1) +def test_timedelta_index_mean(skipna, data): + native_index = native_pd.TimedeltaIndex(data) + snow_index = pd.Index(native_index) + native_result = native_index.mean(skipna=skipna) + snow_result = snow_index.mean(skipna=skipna) + # Special check for NaN because Nan != Nan. + if pd.isna(native_result): + assert pd.isna(snow_result) + else: + assert snow_result == native_result + + +@sql_count_checker(query_count=0) +def test_timedelta_index_mean_invalid_axis(): + native_index = native_pd.TimedeltaIndex([1, 2, 3]) + snow_index = pd.Index(native_index) + with pytest.raises(IndexError, match="tuple index out of range"): + native_index.mean(axis=1) + # Snowpark pandas raises ValueError instead of IndexError. + with pytest.raises(ValueError, match="axis should be 0 for TimedeltaIndex.mean"): + snow_index.mean(axis=1).to_pandas() From fe51d4d167760bb0a858de0e404e9d7a644fb3a9 Mon Sep 17 00:00:00 2001 From: Hazem Elmeleegy Date: Wed, 11 Sep 2024 15:37:40 -0700 Subject: [PATCH 5/5] SNOW-1646704, SNOW-1646706: Add support for Series.dt.tz_convert/tz_localize (#2261) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1646704, SNOW-1646706 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. Add support for Series.dt.tz_convert/tz_localize. 
--- CHANGELOG.md | 1 + docs/source/modin/series.rst | 2 + .../modin/supported/series_dt_supported.rst | 5 +- .../modin/plugin/_internal/timestamp_utils.py | 65 +++++++ .../compiler/snowflake_query_compiler.py | 27 ++- .../modin/plugin/docstrings/series_utils.py | 175 +++++++++++++++++- tests/integ/modin/series/test_dt_accessor.py | 116 ++++++++++++ tests/unit/modin/test_series_dt.py | 2 - 8 files changed, 381 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c59d53ada05..87056a63f51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -118,6 +118,7 @@ - Added support for string indexing with `Timedelta` objects. - Added support for `Series.dt.total_seconds` method. - Added support for `DataFrame.apply(axis=0)`. +- Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`. #### Improvements diff --git a/docs/source/modin/series.rst b/docs/source/modin/series.rst index 188bdab344a..4cb8a238b0f 100644 --- a/docs/source/modin/series.rst +++ b/docs/source/modin/series.rst @@ -279,6 +279,8 @@ Series Series.dt.seconds Series.dt.microseconds Series.dt.nanoseconds + Series.dt.tz_convert + Series.dt.tz_localize .. rubric:: String accessor methods diff --git a/docs/source/modin/supported/series_dt_supported.rst b/docs/source/modin/supported/series_dt_supported.rst index 3377a3d64e2..68853871ea6 100644 --- a/docs/source/modin/supported/series_dt_supported.rst +++ b/docs/source/modin/supported/series_dt_supported.rst @@ -80,9 +80,10 @@ the method in the left column. +-----------------------------+---------------------------------+----------------------------------------------------+ | ``to_pydatetime`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``tz_localize`` | N | | +| ``tz_localize`` | P | ``N`` if `ambiguous` or `nonexistent` are set to a | +| | | non-default value. | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``tz_convert`` | N | | +| ``tz_convert`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``normalize`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py index 0242177d1f0..f8629e664f3 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py @@ -22,9 +22,17 @@ cast, convert_timezone, date_part, + dayofmonth, + hour, iff, + minute, + month, + second, + timestamp_tz_from_parts, to_decimal, + to_timestamp_ntz, trunc, + year, ) from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage @@ -467,3 +475,60 @@ def convert_dateoffset_to_interval( ) interval_kwargs[new_param] = offset return Interval(**interval_kwargs) + + +def tz_localize_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column: + """ + Localize tz-naive to tz-aware. + Args: + tz : str, pytz.timezone, optional + Localize a tz-naive datetime column to tz-aware + + Args: + column: the Snowpark datetime column + tz: time zone for time. Corresponding timestamps would be converted to this time zone of the Datetime Array/Index. 
A tz of None will convert to UTC and remove the timezone information. + + Returns: + The column after tz localization + """ + if tz is None: + # If this column is already a TIMESTAMP_NTZ, this cast does nothing. + # If the column is a TIMESTAMP_TZ, the cast drops the timezone and converts + # to TIMESTAMP_NTZ. + return to_timestamp_ntz(column) + else: + if isinstance(tz, dt.tzinfo): + tz_name = tz.tzname(None) + else: + tz_name = tz + return timestamp_tz_from_parts( + year(column), + month(column), + dayofmonth(column), + hour(column), + minute(column), + second(column), + date_part("nanosecond", column), + pandas_lit(tz_name), + ) + + +def tz_convert_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column: + """ + Converts a datetime column to the specified timezone + + Args: + column: the Snowpark datetime column + tz: the target timezone + + Returns: + The column after conversion to the specified timezone + """ + if tz is None: + return convert_timezone(pandas_lit("UTC"), column) + else: + if isinstance(tz, dt.tzinfo): + tz_name = tz.tzname(None) + else: + tz_name = tz + return convert_timezone(pandas_lit(tz_name), column) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index b1a2736d120..00436f94eec 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -279,6 +279,8 @@ raise_if_to_datetime_not_supported, timedelta_freq_to_nanos, to_snowflake_timestamp_format, + tz_convert_column, + tz_localize_column, ) from snowflake.snowpark.modin.plugin._internal.transpose_utils import ( clean_up_transpose_result_index_and_labels, @@ -16666,7 +16668,7 @@ def dt_tz_localize( tz: Union[str, tzinfo], ambiguous: str = "raise", nonexistent: str = "raise", - ) -> None: + ) -> "SnowflakeQueryCompiler": """ Localize tz-naive to tz-aware. Args: @@ -16678,11 +16680,22 @@ def dt_tz_localize( BaseQueryCompiler New QueryCompiler containing values with localized time zone. """ - ErrorMessage.not_implemented( - "Snowpark pandas doesn't yet support the method 'Series.dt.tz_localize'" + if not isinstance(ambiguous, str) or ambiguous != "raise": + ErrorMessage.parameter_not_implemented_error( + "ambiguous", "Series.dt.tz_localize" + ) + if not isinstance(nonexistent, str) or nonexistent != "raise": + ErrorMessage.parameter_not_implemented_error( + "nonexistent", "Series.dt.tz_localize" + ) + + return SnowflakeQueryCompiler( + self._modin_frame.apply_snowpark_function_to_columns( + lambda column: tz_localize_column(column, tz) + ) ) - def dt_tz_convert(self, tz: Union[str, tzinfo]) -> None: + def dt_tz_convert(self, tz: Union[str, tzinfo]) -> "SnowflakeQueryCompiler": """ Convert time-series data to the specified time zone. @@ -16692,8 +16705,10 @@ def dt_tz_convert(self, tz: Union[str, tzinfo]) -> None: Returns: A new QueryCompiler containing values with converted time zone. 
""" - ErrorMessage.not_implemented( - "Snowpark pandas doesn't yet support the method 'Series.dt.tz_convert'" + return SnowflakeQueryCompiler( + self._modin_frame.apply_snowpark_function_to_columns( + lambda column: tz_convert_column(column, tz) + ) ) def dt_ceil( diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py index 88c4029a92c..b05d7d76db6 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py @@ -1858,10 +1858,181 @@ def to_pydatetime(): pass def tz_localize(): - pass + """ + Localize tz-naive Datetime Array/Index to tz-aware Datetime Array/Index. + + This method takes a time zone (tz) naive Datetime Array/Index object and makes this time zone aware. It does not move the time to another time zone. + + This method can also be used to do the inverse – to create a time zone unaware object from an aware object. To that end, pass tz=None. + + Parameters + ---------- + tz : str, pytz.timezone, dateutil.tz.tzfile, datetime.tzinfo or None + Time zone to convert timestamps to. Passing None will remove the time zone information preserving local time. + ambiguous : ‘infer’, ‘NaT’, bool array, default ‘raise’ + When clocks moved backward due to DST, ambiguous times may arise. For example in Central European Time (UTC+01), when going from 03:00 DST to 02:00 non-DST, 02:30:00 local time occurs both at 00:30:00 UTC and at 01:30:00 UTC. In such a situation, the ambiguous parameter dictates how ambiguous times should be handled. + - ‘infer’ will attempt to infer fall dst-transition hours based on order + - bool-ndarray where True signifies a DST time, False signifies a non-DST time (note that this flag is only applicable for ambiguous times) + - ‘NaT’ will return NaT where there are ambiguous times + - ‘raise’ will raise an AmbiguousTimeError if there are ambiguous times. + nonexistent : ‘shift_forward’, ‘shift_backward, ‘NaT’, timedelta, default ‘raise’ + A nonexistent time does not exist in a particular timezone where clocks moved forward due to DST. + - ‘shift_forward’ will shift the nonexistent time forward to the closest existing time + - ‘shift_backward’ will shift the nonexistent time backward to the closest existing time + - ‘NaT’ will return NaT where there are nonexistent times + - timedelta objects will shift nonexistent times by the timedelta + - ‘raise’ will raise an NonExistentTimeError if there are nonexistent times. + + Returns + ------- + Same type as self + Array/Index converted to the specified time zone. + + Raises + ------ + TypeError + If the Datetime Array/Index is tz-aware and tz is not None. + + See also + -------- + DatetimeIndex.tz_convert + Convert tz-aware DatetimeIndex from one time zone to another. 
+ + Examples + -------- + >>> tz_naive = pd.date_range('2018-03-01 09:00', periods=3) + >>> tz_naive + DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', + '2018-03-03 09:00:00'], + dtype='datetime64[ns]', freq=None) + + Localize DatetimeIndex in US/Eastern time zone: + + >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP + >>> tz_aware # doctest: +SKIP + DatetimeIndex(['2018-03-01 09:00:00-05:00', + '2018-03-02 09:00:00-05:00', + '2018-03-03 09:00:00-05:00'], + dtype='datetime64[ns, US/Eastern]', freq=None) + + With the tz=None, we can remove the time zone information while keeping the local time (not converted to UTC): + + >>> tz_aware.tz_localize(None) # doctest: +SKIP + DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', + '2018-03-03 09:00:00'], + dtype='datetime64[ns]', freq=None) + + Be careful with DST changes. When there is sequential data, pandas can infer the DST time: + + >>> s = pd.to_datetime(pd.Series(['2018-10-28 01:30:00', + ... '2018-10-28 02:00:00', + ... '2018-10-28 02:30:00', + ... '2018-10-28 02:00:00', + ... '2018-10-28 02:30:00', + ... '2018-10-28 03:00:00', + ... '2018-10-28 03:30:00'])) + >>> s.dt.tz_localize('CET', ambiguous='infer') # doctest: +SKIP + 0 2018-10-28 01:30:00+02:00 + 1 2018-10-28 02:00:00+02:00 + 2 2018-10-28 02:30:00+02:00 + 3 2018-10-28 02:00:00+01:00 + 4 2018-10-28 02:30:00+01:00 + 5 2018-10-28 03:00:00+01:00 + 6 2018-10-28 03:30:00+01:00 + dtype: datetime64[ns, CET] + + In some cases, inferring the DST is impossible. In such cases, you can pass an ndarray to the ambiguous parameter to set the DST explicitly + + >>> s = pd.to_datetime(pd.Series(['2018-10-28 01:20:00', + ... '2018-10-28 02:36:00', + ... '2018-10-28 03:46:00'])) + >>> s.dt.tz_localize('CET', ambiguous=np.array([True, True, False])) # doctest: +SKIP + 0 2018-10-28 01:20:00+02:00 + 1 2018-10-28 02:36:00+02:00 + 2 2018-10-28 03:46:00+01:00 + dtype: datetime64[ns, CET] + + If the DST transition causes nonexistent times, you can shift these dates forward or backwards with a timedelta object or ‘shift_forward’ or ‘shift_backwards’. + + >>> s = pd.to_datetime(pd.Series(['2015-03-29 02:30:00', + ... '2015-03-29 03:30:00'])) + >>> s.dt.tz_localize('Europe/Warsaw', nonexistent='shift_forward') # doctest: +SKIP + 0 2015-03-29 03:00:00+02:00 + 1 2015-03-29 03:30:00+02:00 + dtype: datetime64[ns, Europe/Warsaw] + + >>> s.dt.tz_localize('Europe/Warsaw', nonexistent='shift_backward') # doctest: +SKIP + 0 2015-03-29 01:59:59.999999999+01:00 + 1 2015-03-29 03:30:00+02:00 + dtype: datetime64[ns, Europe/Warsaw] + + >>> s.dt.tz_localize('Europe/Warsaw', nonexistent=pd.Timedelta('1h')) # doctest: +SKIP + 0 2015-03-29 03:30:00+02:00 + 1 2015-03-29 03:30:00+02:00 + dtype: datetime64[ns, Europe/Warsaw] + """ def tz_convert(): - pass + """ + Convert tz-aware Datetime Array/Index from one time zone to another. + + Parameters + ---------- + tz : str, pytz.timezone, dateutil.tz.tzfile, datetime.tzinfo or None + Time zone for time. Corresponding timestamps would be converted to this time zone of the Datetime Array/Index. A tz of None will convert to UTC and remove the timezone information. + + Returns + ------- + Array or Index + + Raises + ------ + TypeError + If Datetime Array/Index is tz-naive. + + See also + DatetimeIndex.tz + A timezone that has a variable offset from UTC. + DatetimeIndex.tz_localize + Localize tz-naive DatetimeIndex to a given time zone, or remove timezone from a tz-aware DatetimeIndex. 
+ + Examples + -------- + With the tz parameter, we can change the DatetimeIndex to other time zones: + + >>> dti = pd.date_range(start='2014-08-01 09:00', + ... freq='h', periods=3, tz='Europe/Berlin') # doctest: +SKIP + + >>> dti # doctest: +SKIP + DatetimeIndex(['2014-08-01 09:00:00+02:00', + '2014-08-01 10:00:00+02:00', + '2014-08-01 11:00:00+02:00'], + dtype='datetime64[ns, Europe/Berlin]', freq='h') + + >>> dti.tz_convert('US/Central') # doctest: +SKIP + DatetimeIndex(['2014-08-01 02:00:00-05:00', + '2014-08-01 03:00:00-05:00', + '2014-08-01 04:00:00-05:00'], + dtype='datetime64[ns, US/Central]', freq='h') + + With the tz=None, we can remove the timezone (after converting to UTC if necessary): + + >>> dti = pd.date_range(start='2014-08-01 09:00', freq='h', + ... periods=3, tz='Europe/Berlin') # doctest: +SKIP + + >>> dti # doctest: +SKIP + DatetimeIndex(['2014-08-01 09:00:00+02:00', + '2014-08-01 10:00:00+02:00', + '2014-08-01 11:00:00+02:00'], + dtype='datetime64[ns, Europe/Berlin]', freq='h') + + >>> dti.tz_convert(None) # doctest: +SKIP + DatetimeIndex(['2014-08-01 07:00:00', + '2014-08-01 08:00:00', + '2014-08-01 09:00:00'], + dtype='datetime64[ns]', freq='h') + """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. def normalize(): pass diff --git a/tests/integ/modin/series/test_dt_accessor.py b/tests/integ/modin/series/test_dt_accessor.py index 0e1cacf8fc0..d1795fa2c80 100644 --- a/tests/integ/modin/series/test_dt_accessor.py +++ b/tests/integ/modin/series/test_dt_accessor.py @@ -5,8 +5,10 @@ import datetime import modin.pandas as pd +import numpy as np import pandas as native_pd import pytest +import pytz import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker @@ -39,6 +41,47 @@ ) +timezones = pytest.mark.parametrize( + "tz", + [ + None, + # Use a subset of pytz.common_timezones containing a few timezones in each + *[ + param_for_one_tz + for tz in [ + "Africa/Abidjan", + "Africa/Timbuktu", + "America/Adak", + "America/Yellowknife", + "Antarctica/Casey", + "Asia/Dhaka", + "Asia/Manila", + "Asia/Shanghai", + "Atlantic/Stanley", + "Australia/Sydney", + "Canada/Pacific", + "Europe/Chisinau", + "Europe/Luxembourg", + "Indian/Christmas", + "Pacific/Chatham", + "Pacific/Wake", + "US/Arizona", + "US/Central", + "US/Eastern", + "US/Hawaii", + "US/Mountain", + "US/Pacific", + "UTC", + ] + for param_for_one_tz in ( + pytz.timezone(tz), + tz, + ) + ], + ], +) + + @pytest.fixture def day_of_week_or_year_data() -> native_pd.Series: return native_pd.Series( @@ -174,6 +217,79 @@ def test_normalize(): ) +@sql_count_checker(query_count=1) +@timezones +def test_tz_convert(tz): + datetime_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + tz="US/Eastern", + ) + native_ser = native_pd.Series(datetime_index) + snow_ser = pd.Series(native_ser) + eval_snowpark_pandas_result( + snow_ser, + native_ser, + lambda s: s.dt.tz_convert(tz), + ) + + +@sql_count_checker(query_count=1) +@timezones +def test_tz_localize(tz): + datetime_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + native_ser = native_pd.Series(datetime_index) + snow_ser = pd.Series(native_ser) + eval_snowpark_pandas_result( + snow_ser, + native_ser, + lambda 
s: s.dt.tz_localize(tz), + ) + + +@pytest.mark.parametrize( + "ambiguous, nonexistent", + [ + ("infer", "raise"), + ("NaT", "raise"), + (np.array([True, True, False]), "raise"), + ("raise", "shift_forward"), + ("raise", "shift_backward"), + ("raise", "NaT"), + ("raise", pd.Timedelta("1h")), + ("infer", "shift_forward"), + ], +) +@sql_count_checker(query_count=0) +def test_tz_localize_negative(ambiguous, nonexistent): + datetime_index = native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + ) + native_ser = native_pd.Series(datetime_index) + snow_ser = pd.Series(native_ser) + with pytest.raises(NotImplementedError): + snow_ser.dt.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent) + + @pytest.mark.parametrize("name", [None, "hello"]) def test_isocalendar(name): with SqlCounter(query_count=1): diff --git a/tests/unit/modin/test_series_dt.py b/tests/unit/modin/test_series_dt.py index be0039683a8..0b5572f0592 100644 --- a/tests/unit/modin/test_series_dt.py +++ b/tests/unit/modin/test_series_dt.py @@ -32,8 +32,6 @@ def mock_query_compiler_for_dt_series() -> SnowflakeQueryCompiler: [ (lambda s: s.dt.timetz, "timetz"), (lambda s: s.dt.to_period(), "to_period"), - (lambda s: s.dt.tz_localize(tz="UTC"), "tz_localize"), - (lambda s: s.dt.tz_convert(tz="UTC"), "tz_convert"), (lambda s: s.dt.strftime(date_format="YY/MM/DD"), "strftime"), (lambda s: s.dt.qyear, "qyear"), (lambda s: s.dt.start_time, "start_time"),