From d4b4638b7b2e9f007aad3967cdf4d38a4ec48c89 Mon Sep 17 00:00:00 2001 From: Mahesh Vashishtha Date: Tue, 20 Aug 2024 20:25:34 -0700 Subject: [PATCH] SNOW-1320449: Support subtracting two timestamps to get a timedelta. (#2113) Signed-off-by: sfc-gh-mvashishtha Co-authored-by: Naren Krishna --- CHANGELOG.md | 1 + .../modin/plugin/_internal/binary_op_utils.py | 135 +++++- .../snowpark/modin/plugin/_internal/frame.py | 39 +- .../plugin/_internal/snowpark_pandas_types.py | 12 +- .../compiler/snowflake_query_compiler.py | 161 ++++--- tests/integ/modin/binary/test_timedelta.py | 396 ++++++++++++++++++ 6 files changed, 657 insertions(+), 87 deletions(-) create mode 100644 tests/integ/modin/binary/test_timedelta.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fde2cd8ec5..9c017b60188 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ - Added support for `Index.__repr__`. - Added support for `DatetimeIndex.month_name` and `DatetimeIndex.day_name`. - Added support for `Series.dt.weekday`, `Series.dt.time`, and `DatetimeIndex.time`. +- Added support for subtracting two timestamps to get a Timedelta. #### Bug Fixes diff --git a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py index bd48298c5b0..5852f0b8a27 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py @@ -4,15 +4,30 @@ import functools from collections.abc import Hashable from dataclasses import dataclass +from enum import Enum, auto +import pandas as native_pd from pandas._typing import Callable, Scalar from snowflake.snowpark.column import Column as SnowparkColumn -from snowflake.snowpark.functions import col, concat, floor, iff, repeat, when +from snowflake.snowpark.functions import ( + col, + concat, + datediff, + floor, + iff, + is_null, + repeat, + when, +) from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame from snowflake.snowpark.modin.plugin._internal.join_utils import ( JoinOrAlignInternalFrameResult, ) +from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + SnowparkPandasColumn, + TimedeltaType, +) from snowflake.snowpark.modin.plugin._internal.type_utils import infer_object_type from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage @@ -20,6 +35,8 @@ DataType, NullType, StringType, + TimestampTimeZone, + TimestampType, _FractionalType, _IntegralType, ) @@ -172,13 +189,64 @@ def is_binary_op_supported(op: str) -> bool: return op in SUPPORTED_BINARY_OPERATIONS +class SubtractionType(Enum): + """Type of subtraction, i.e. rsub or sub""" + + # SUB means first_operand - second_operand, e.g. pd.Series(2).sub(pd.Series(1)) is equal to pd.Series(1) + SUB = auto() + # RSUB means second_operand - first_operand, e.g. pd.Series(2).rsub(pd.Series(1)) is equal to pd.Series(-1) + RSUB = auto() + + +def _compute_subtraction_between_snowpark_timestamp_columns( + first_operand: SnowparkColumn, + first_datatype: DataType, + second_operand: SnowparkColumn, + second_datatype: DataType, + subtraction_type: SubtractionType, +) -> SnowparkPandasColumn: + """ + Compute subtraction between two snowpark columns. + + Args: + first_operand: SnowparkColumn for lhs + first_datatype: Snowpark datatype for lhs + second_operand: SnowparkColumn for rhs + second_datatype: Snowpark datatype for rhs + subtraction_type: Type of subtraction. 
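
    Returns:
        SnowparkPandasColumn: NULL when either operand is NULL; otherwise the
        difference between the operands in nanoseconds, carrying
        TimedeltaType() as its Snowpark pandas type.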
+ """ + if ( + first_datatype.tz is TimestampTimeZone.NTZ + and second_datatype.tz is TimestampTimeZone.TZ + ) or ( + first_datatype.tz is TimestampTimeZone.TZ + and second_datatype.tz is TimestampTimeZone.NTZ + ): + raise TypeError("Cannot subtract tz-naive and tz-aware datetime-like objects.") + return SnowparkPandasColumn( + iff( + is_null(first_operand).__or__(is_null(second_operand)), + pandas_lit(native_pd.NaT), + datediff( + "ns", + *( + (first_operand, second_operand) + if subtraction_type is SubtractionType.RSUB + else (second_operand, first_operand) + ), + ), + ), + TimedeltaType(), + ) + + def compute_binary_op_between_snowpark_columns( op: str, first_operand: SnowparkColumn, first_datatype: Callable[[], DataType], second_operand: SnowparkColumn, second_datatype: Callable[[], DataType], -) -> SnowparkColumn: +) -> SnowparkPandasColumn: """ Compute pandas binary operation for two SnowparkColumns Args: @@ -191,10 +259,9 @@ def compute_binary_op_between_snowpark_columns( it is not needed. Returns: - SnowparkColumn expr for translated pandas operation + SnowparkPandasColumn for translated pandas operation """ - - binary_op_result_column = None + binary_op_result_column, snowpark_pandas_result_type = None, None # some operators and the data types have to be handled specially to align with pandas # However, it is difficult to fail early if the arithmetic operator is not compatible @@ -272,14 +339,48 @@ def compute_binary_op_between_snowpark_columns( binary_op_result_column = pandas_lit(False) else: binary_op_result_column = first_operand.equal_null(second_operand) - + elif ( + op in ("rsub", "sub") + and isinstance(first_datatype(), TimestampType) + and isinstance(second_datatype(), NullType) + ): + # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, + # but it's valid in pandas and returns NULL. + snowpark_pandas_result_type = NullType() + binary_op_result_column = pandas_lit(None) + elif ( + op in ("rsub", "sub") + and isinstance(first_datatype(), NullType) + and isinstance(second_datatype(), TimestampType) + ): + # Timestamp - NULL or NULL - Timestamp raises SQL compilation error, + # but it's valid in pandas and returns NULL. + snowpark_pandas_result_type = NullType() + binary_op_result_column = pandas_lit(None) + elif ( + op in ("sub", "rsub") + and isinstance(first_datatype(), TimestampType) + and isinstance(second_datatype(), TimestampType) + ): + return _compute_subtraction_between_snowpark_timestamp_columns( + first_operand=first_operand, + first_datatype=first_datatype(), + second_operand=second_operand, + second_datatype=second_datatype(), + subtraction_type=SubtractionType.SUB + if op == "sub" + else SubtractionType.RSUB, + ) # If there is no special binary_op_result_column result, it means the operator and # the data type of the column don't need special handling. Then we get the overloaded # operator from Snowpark Column class, e.g., __add__ to perform binary operations. 
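    # For example, op="mul" resolves to SnowparkColumn.__mul__, producing
    # first_operand * second_operand; since no branch above set
    # snowpark_pandas_result_type, the result carries no Snowpark pandas type.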
if binary_op_result_column is None: binary_op_result_column = getattr(first_operand, f"__{op}__")(second_operand) - return binary_op_result_column + return SnowparkPandasColumn( + snowpark_column=binary_op_result_column, + snowpark_pandas_type=snowpark_pandas_result_type, + ) def are_equal_types(type1: DataType, type2: DataType) -> bool: @@ -307,7 +408,7 @@ def compute_binary_op_between_snowpark_column_and_scalar( first_operand: SnowparkColumn, datatype: Callable[[], DataType], second_operand: Scalar, -) -> SnowparkColumn: +) -> SnowparkPandasColumn: """ Compute the binary operation between a Snowpark column and a scalar. Args: @@ -318,16 +419,14 @@ def compute_binary_op_between_snowpark_column_and_scalar( second_operand: Scalar value Returns: - The result as a Snowpark column + SnowparkPandasColumn for translated pandas operation """ def second_datatype() -> DataType: return infer_object_type(second_operand) - second_operand = pandas_lit(second_operand) - return compute_binary_op_between_snowpark_columns( - op, first_operand, datatype, second_operand, second_datatype + op, first_operand, datatype, pandas_lit(second_operand), second_datatype ) @@ -336,7 +435,7 @@ def compute_binary_op_between_scalar_and_snowpark_column( first_operand: Scalar, second_operand: SnowparkColumn, datatype: Callable[[], DataType], -) -> SnowparkColumn: +) -> SnowparkPandasColumn: """ Compute the binary operation between a scalar and a Snowpark column. Args: @@ -347,16 +446,14 @@ def compute_binary_op_between_scalar_and_snowpark_column( it is not needed. Returns: - The result as a Snowpark column + SnowparkPandasColumn for translated pandas operation """ def first_datatype() -> DataType: return infer_object_type(first_operand) - first_operand = pandas_lit(first_operand) - return compute_binary_op_between_snowpark_columns( - op, first_operand, first_datatype, second_operand, datatype + op, pandas_lit(first_operand), first_datatype, second_operand, datatype ) @@ -367,7 +464,7 @@ def compute_binary_op_with_fill_value( rhs: SnowparkColumn, rhs_datatype: Callable[[], DataType], fill_value: Scalar, -) -> SnowparkColumn: +) -> SnowparkPandasColumn: """ Helper method for performing binary operations. 1. Fills NaN/None values in the lhs and rhs with the given fill_value. @@ -392,7 +489,7 @@ def compute_binary_op_with_fill_value( successful DataFrame alignment, with this value before computation. Returns: - SnowparkColumn expression for translated pandas operation + SnowparkPandasColumn for translated pandas operation """ lhs_cond, rhs_cond = lhs, rhs if fill_value is not None: diff --git a/src/snowflake/snowpark/modin/plugin/_internal/frame.py b/src/snowflake/snowpark/modin/plugin/_internal/frame.py index 36ab44097e9..9834d54eafb 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/frame.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/frame.py @@ -924,7 +924,10 @@ def persist_to_temporary_table(self) -> "InternalFrame": ) def append_column( - self, pandas_label: Hashable, value: SnowparkColumn + self, + pandas_label: Hashable, + value: SnowparkColumn, + value_type: Optional[SnowparkPandasType] = None, ) -> "InternalFrame": """ Append a column to this frame. The column is added at the end. For a frame with multiindex column, it @@ -935,6 +938,7 @@ def append_column( Args: pandas_label: pandas label for column to be inserted. value: SnowparkColumn. + value_type: The optional SnowparkPandasType for the new column. Returns: A new InternalFrame with new column. 
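        Example (illustrative sketch; ``expr`` stands for any SnowparkColumn
        expression, e.g. one produced by the timestamp subtraction above):
            frame.append_column("delta", expr, value_type=TimedeltaType())
            # the new data column's cached Snowpark pandas type is
            # TimedeltaType() instead of the previously hard-coded None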
@@ -975,7 +979,8 @@ def append_column(
             data_column_pandas_index_names=self.data_column_pandas_index_names,
             index_column_pandas_labels=self.index_column_pandas_labels,
             index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers,
-            data_column_types=self.cached_data_column_snowpark_pandas_types + [None],
+            data_column_types=self.cached_data_column_snowpark_pandas_types
+            + [value_type],
             index_column_types=self.cached_index_column_snowpark_pandas_types,
         )

@@ -1109,6 +1114,9 @@ def get_updated_identifiers(identifiers: list[str]) -> list[str]:
     def update_snowflake_quoted_identifiers_with_expressions(
         self,
         quoted_identifier_to_column_map: dict[str, SnowparkColumn],
+        data_column_snowpark_pandas_types: Optional[
+            list[Optional[SnowparkPandasType]]
+        ] = None,
     ) -> UpdatedInternalFrameResult:
         """
         Points Snowflake quoted identifiers to column expressions given by `quoted_identifier_to_column_map`.
@@ -1134,6 +1142,8 @@ def update_snowflake_quoted_identifiers_with_expressions(
                 existing snowflake quoted identifiers to new Snowpark columns. As keys
                 of a dictionary, all snowflake column identifiers are unique here and
                 must be index columns and data columns in the original internal frame.
+            data_column_snowpark_pandas_types: The optional Snowpark pandas types for the new
+                expressions, in the order of the keys of quoted_identifier_to_column_map.

         Returns:
             UpdatedInternalFrameResult: A tuple containing the new InternalFrame with updated column references, and a mapping
@@ -1168,10 +1178,16 @@ def update_snowflake_quoted_identifiers_with_expressions(
         existing_id_to_new_id_mapping = {}
         new_columns = []
-        for (
-            existing_identifier,
-            column_expression,
-        ) in quoted_identifier_to_column_map.items():
+        new_type_mapping = dict(
+            self.snowflake_quoted_identifier_to_snowpark_pandas_type
+        )
+        if data_column_snowpark_pandas_types is None:
+            data_column_snowpark_pandas_types = [None] * len(
+                quoted_identifier_to_column_map
+            )
+        for ((existing_identifier, column_expression,), data_column_type) in zip(
+            quoted_identifier_to_column_map.items(), data_column_snowpark_pandas_types
+        ):
             new_identifier = (
                 self.ordered_dataframe.generate_snowflake_quoted_identifiers(
                     pandas_labels=[
@@ -1183,6 +1199,7 @@ def update_snowflake_quoted_identifiers_with_expressions(
             )
             existing_id_to_new_id_mapping[existing_identifier] = new_identifier
             new_columns.append(column_expression)
+            new_type_mapping[new_identifier] = data_column_type
         new_ordered_dataframe = append_columns(
             self.ordered_dataframe,
             list(existing_id_to_new_id_mapping.values()),
@@ -1206,10 +1223,16 @@ def update_snowflake_quoted_identifiers_with_expressions(
                 data_column_pandas_labels=self.data_column_pandas_labels,
                 data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers,
                 data_column_pandas_index_names=self.data_column_pandas_index_names,
-                data_column_types=self.cached_data_column_snowpark_pandas_types,
                 index_column_pandas_labels=self.index_column_pandas_labels,
                 index_column_snowflake_quoted_identifiers=new_index_column_snowflake_quoted_identifiers,
-                index_column_types=self.cached_index_column_snowpark_pandas_types,
+                data_column_types=[
+                    new_type_mapping[k]
+                    for k in new_data_column_snowflake_quoted_identifiers
+                ],
+                index_column_types=[
+                    new_type_mapping[k]
+                    for k in new_index_column_snowflake_quoted_identifiers
+                ],
             ),
             existing_id_to_new_id_mapping,
         )
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py
b/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py index d87ce659a98..0f58a94fd03 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py @@ -6,11 +6,12 @@ import inspect from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, NamedTuple, Optional, Union import numpy as np import pandas as native_pd +from snowflake.snowpark.column import Column from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage from snowflake.snowpark.types import LongType @@ -99,6 +100,15 @@ def get_snowpark_pandas_type_for_pandas_type( return _pandas_type_to_snowpark_pandas_type.get(pandas_type, None) +class SnowparkPandasColumn(NamedTuple): + """A Snowpark Column that has an optional SnowparkPandasType.""" + + # The Snowpark Column. + snowpark_column: Column + # The SnowparkPandasType for the column, if the type of the column is a SnowparkPandasType. + snowpark_pandas_type: Optional[SnowparkPandasType] + + class TimedeltaType(SnowparkPandasType, LongType): """ Timedelta represents the difference between two times. diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index d851918654d..eab23c63616 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -262,6 +262,8 @@ validate_resample_supported_by_snowflake, ) from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + SnowparkPandasColumn, + SnowparkPandasType, TimedeltaType, ) from snowflake.snowpark.modin.plugin._internal.timestamp_utils import ( @@ -1483,7 +1485,9 @@ def _shift_values_axis_0( fill_value = pandas_lit(fill_value) if fill_value is not None else None type_map = frame.quoted_identifier_to_snowflake_type() - def shift_expression(quoted_identifier: str, dtype: DataType) -> SnowparkColumn: + def shift_expression_and_type( + quoted_identifier: str, dtype: DataType + ) -> SnowparkPandasColumn: """ Helper function to generate lag-based shift expression for Snowpark pandas. Performs necessary type conversion if datatype of fill_value is not compatible with a column's datatype. @@ -1492,29 +1496,45 @@ def shift_expression(quoted_identifier: str, dtype: DataType) -> SnowparkColumn: dtype: datatype of column identified by quoted_identifier Returns: - SnowparkColumn columnar expression + SnowparkPandasColumn representing the result. """ window_expr = Window.orderBy(col(row_position_quoted_identifier)) # convert to variant type if types differ if fill_value is not None and dtype != fill_value_dtype: - return lag( + shift_expression = lag( to_variant(col(quoted_identifier)), offset=periods, default_value=to_variant(fill_value), ).over(window_expr) + expression_type = VariantType() else: - return lag( + shift_expression = lag( quoted_identifier, offset=periods, default_value=fill_value ).over(window_expr) - + expression_type = dtype + # TODO(https://snowflakecomputing.atlassian.net/browse/SNOW-1634393): + # Prevent ourselves from using types that are DataType but not + # SnowparkPandasType. 
In this particular case, the type should + # indeed be Optional[SnowparkPandasType] + return ( + shift_expression, + expression_type + if isinstance(expression_type, SnowparkPandasType) + else None, + ) + + quoted_identifier_to_column_map = {} + data_column_snowpark_pandas_types = [] + for identifier in frame.data_column_snowflake_quoted_identifiers: + expression, snowpark_pandas_type = shift_expression_and_type( + identifier, type_map[identifier] + ) + quoted_identifier_to_column_map[identifier] = expression + data_column_snowpark_pandas_types.append(snowpark_pandas_type) new_frame = frame.update_snowflake_quoted_identifiers_with_expressions( - { - quoted_identifier: shift_expression( - quoted_identifier, type_map[quoted_identifier] - ) - for quoted_identifier in frame.data_column_snowflake_quoted_identifiers - } + quoted_identifier_to_column_map=quoted_identifier_to_column_map, + data_column_snowpark_pandas_types=data_column_snowpark_pandas_types, ).frame return self.__constructor__(new_frame) @@ -1834,8 +1854,10 @@ def _binary_op_scalar_rhs( only arithmetic binary operation has this parameter (e.g., add() has, but eq() doesn't have). """ type_map = self._modin_frame.quoted_identifier_to_snowflake_type() - replace_mapping = { - identifier: compute_binary_op_with_fill_value( + replace_mapping = {} + data_column_snowpark_pandas_types = [] + for identifier in self._modin_frame.data_column_snowflake_quoted_identifiers: + expression, snowpark_pandas_type = compute_binary_op_with_fill_value( op=op, lhs=col(identifier), lhs_datatype=lambda: type_map[identifier], # noqa: B023 @@ -1843,11 +1865,12 @@ def _binary_op_scalar_rhs( rhs_datatype=lambda: infer_object_type(other), fill_value=fill_value, ) - for identifier in self._modin_frame.data_column_snowflake_quoted_identifiers - } + replace_mapping[identifier] = expression + data_column_snowpark_pandas_types.append(snowpark_pandas_type) return SnowflakeQueryCompiler( self._modin_frame.update_snowflake_quoted_identifiers_with_expressions( - replace_mapping + quoted_identifier_to_column_map=replace_mapping, + data_column_snowpark_pandas_types=data_column_snowpark_pandas_types, ).frame ) @@ -1892,8 +1915,10 @@ def _binary_op_list_like_rhs_axis_0( other_identifier = new_frame.data_column_snowflake_quoted_identifiers[-1] # Step 3: Create a map from the column identifier to the binary operation expression. This is used # to update the column data. - replace_mapping = { - identifier: compute_binary_op_with_fill_value( + replace_mapping = {} + snowpark_pandas_types = [] + for identifier in new_frame.data_column_snowflake_quoted_identifiers[:-1]: + expression, snowpark_pandas_type = compute_binary_op_with_fill_value( op=op, lhs=col(identifier), lhs_datatype=lambda: identifier_to_type_map[identifier], # noqa: B023 @@ -1901,8 +1926,8 @@ def _binary_op_list_like_rhs_axis_0( rhs_datatype=lambda: identifier_to_type_map[other_identifier], fill_value=fill_value, ) - for identifier in new_frame.data_column_snowflake_quoted_identifiers[:-1] - } + replace_mapping[identifier] = expression + snowpark_pandas_types.append(snowpark_pandas_type) # Step 4: Update the frame with the expressions map and return a new query compiler after removing the # column representing other's data. 
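        # Each entry in snowpark_pandas_types lines up positionally with the
        # keys of replace_mapping: a timestamp - timestamp result carries
        # TimedeltaType(), while every other result carries None.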
@@ -1918,7 +1943,7 @@ def _binary_op_list_like_rhs_axis_0( data_column_pandas_index_names=new_frame.data_column_pandas_index_names, index_column_pandas_labels=new_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_frame.index_column_snowflake_quoted_identifiers, - data_column_types=None, + data_column_types=snowpark_pandas_types, index_column_types=None, ) return SnowflakeQueryCompiler(new_frame) @@ -1950,6 +1975,7 @@ def _binary_op_list_like_rhs_axis_1( # each element in the list-like object can be treated as a scalar for each corresponding column. type_map = self._modin_frame.quoted_identifier_to_snowflake_type() + snowpark_pandas_types = [] for idx, identifier in enumerate( self._modin_frame.data_column_snowflake_quoted_identifiers ): @@ -1963,7 +1989,7 @@ def _binary_op_list_like_rhs_axis_1( # rhs is not guaranteed to be a scalar value - it can be a list-like as well. # Convert all list-like objects to a list. rhs_lit = pandas_lit(rhs) if is_scalar(rhs) else pandas_lit(rhs.tolist()) - replace_mapping[identifier] = compute_binary_op_with_fill_value( + expression, snowpark_pandas_type = compute_binary_op_with_fill_value( op, lhs=lhs, lhs_datatype=lambda: type_map[identifier], # noqa: B023 @@ -1971,10 +1997,12 @@ def _binary_op_list_like_rhs_axis_1( rhs_datatype=lambda: infer_object_type(rhs), # noqa: B023 fill_value=fill_value, ) + replace_mapping[identifier] = expression + snowpark_pandas_types.append(snowpark_pandas_type) return SnowflakeQueryCompiler( self._modin_frame.update_snowflake_quoted_identifiers_with_expressions( - replace_mapping + replace_mapping, snowpark_pandas_types ).frame ) @@ -2096,7 +2124,7 @@ def binary_op( )[0] # add new column with result as unnamed - new_column_expr = compute_binary_op_with_fill_value( + new_column_expr, snowpark_pandas_type = compute_binary_op_with_fill_value( op=op, lhs=col(lhs_quoted_identifier), lhs_datatype=lambda: identifier_to_type_map[lhs_quoted_identifier], @@ -2113,7 +2141,9 @@ def binary_op( else lhs_frame.data_column_pandas_labels[0] ) - new_frame = aligned_frame.append_column(new_column_name, new_column_expr) + new_frame = aligned_frame.append_column( + new_column_name, new_column_expr, value_type=snowpark_pandas_type + ) # return only newly created column. 
Because column has been appended, this is the last column indexed by -1 return SnowflakeQueryCompiler( @@ -13675,15 +13705,19 @@ def create_lazy_type_functions( right_datatype = right_datatypes[0] # now replace in result frame identifiers with binary op result + replace_mapping = {} + snowpark_pandas_types = [] + for left, left_datatype in zip(left_result_data_identifiers, left_datatypes): + ( + expression, + snowpark_pandas_type, + ) = compute_binary_op_between_snowpark_columns( + op, col(left), left_datatype, col(right), right_datatype + ) + snowpark_pandas_types.append(snowpark_pandas_type) + replace_mapping[left] = expression update_result = joined_frame.result_frame.update_snowflake_quoted_identifiers_with_expressions( - { - left: compute_binary_op_between_snowpark_columns( - op, col(left), left_datatype, col(right), right_datatype - ) - for left, left_datatype in zip( - left_result_data_identifiers, left_datatypes - ) - } + replace_mapping, snowpark_pandas_types ) new_frame = update_result.frame @@ -13691,22 +13725,25 @@ def create_lazy_type_functions( identifiers_to_keep = set( new_frame.index_column_snowflake_quoted_identifiers ) | set(update_result.old_id_to_new_id_mappings.values()) - label_to_snowflake_quoted_identifier = tuple( - filter( - lambda pair: pair.snowflake_quoted_identifier in identifiers_to_keep, - new_frame.label_to_snowflake_quoted_identifier, - ) - ) + label_to_snowflake_quoted_identifier = [] + snowflake_quoted_identifier_to_snowpark_pandas_type = {} + for pair in new_frame.label_to_snowflake_quoted_identifier: + if pair.snowflake_quoted_identifier in identifiers_to_keep: + label_to_snowflake_quoted_identifier.append(pair) + snowflake_quoted_identifier_to_snowpark_pandas_type[ + pair.snowflake_quoted_identifier + ] = new_frame.snowflake_quoted_identifier_to_snowpark_pandas_type[ + pair.snowflake_quoted_identifier + ] new_frame = InternalFrame( ordered_dataframe=new_frame.ordered_dataframe, - label_to_snowflake_quoted_identifier=label_to_snowflake_quoted_identifier, + label_to_snowflake_quoted_identifier=tuple( + label_to_snowflake_quoted_identifier + ), num_index_columns=new_frame.num_index_columns, data_column_index_names=new_frame.data_column_index_names, - snowflake_quoted_identifier_to_snowpark_pandas_type={ - pair.snowflake_quoted_identifier: None - for pair in label_to_snowflake_quoted_identifier - }, + snowflake_quoted_identifier_to_snowpark_pandas_type=snowflake_quoted_identifier_to_snowpark_pandas_type, ) return SnowflakeQueryCompiler(new_frame) @@ -13923,8 +13960,10 @@ def infer_sorted_column_labels( align_result, combined_data_labels, self_frame, other_frame ) - replace_mapping = { - p.identifier: compute_binary_op_with_fill_value( + replace_mapping = {} + data_column_snowpark_pandas_types = [] + for p in left_right_pairs: + result_expression, snowpark_pandas_type = compute_binary_op_with_fill_value( op=op, lhs=p.lhs, lhs_datatype=p.lhs_datatype, @@ -13932,9 +13971,8 @@ def infer_sorted_column_labels( rhs_datatype=p.rhs_datatype, fill_value=fill_value, ) - for p in left_right_pairs - } - + replace_mapping[p.identifier] = result_expression + data_column_snowpark_pandas_types.append(snowpark_pandas_type) # Create restricted frame with only combined / replaced labels. 
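        # The data_column_snowpark_pandas_types collected above are applied via
        # data_column_types when the restricted frame is rebuilt below, so
        # Timedelta results keep their Snowpark pandas type.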
updated_result = align_result.result_frame.update_snowflake_quoted_identifiers_with_expressions( replace_mapping @@ -13951,7 +13989,7 @@ def infer_sorted_column_labels( data_column_snowflake_quoted_identifiers=updated_data_identifiers, index_column_pandas_labels=new_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_frame.index_column_snowflake_quoted_identifiers, - data_column_types=None, + data_column_types=data_column_snowpark_pandas_types, index_column_types=None, ) @@ -14192,9 +14230,11 @@ def infer_sorted_column_labels( for _, identifier in overlapping_pairs # noqa: B023 } - new_frame = new_frame.update_snowflake_quoted_identifiers_with_expressions( - { - identifier: compute_binary_op_between_scalar_and_snowpark_column( + replace_mapping = {} + snowpark_pandas_labels = [] + for label, identifier in overlapping_pairs: + expression, new_type = ( + compute_binary_op_between_scalar_and_snowpark_column( op, series.loc[label], col(identifier), @@ -14207,11 +14247,14 @@ def infer_sorted_column_labels( datatype_getters[identifier], series.loc[label], ) - for label, identifier in overlapping_pairs - } - ).frame - - return SnowflakeQueryCompiler(new_frame) + ) + snowpark_pandas_labels.append(new_type) + replace_mapping[identifier] = expression + return SnowflakeQueryCompiler( + new_frame.update_snowflake_quoted_identifiers_with_expressions( + replace_mapping, snowpark_pandas_labels + ).frame + ) def _replace_non_str( self, @@ -16644,7 +16687,7 @@ def equals( replace_mapping = { p.identifier: compute_binary_op_between_snowpark_columns( "equal_null", p.lhs, p.lhs_datatype, p.rhs, p.rhs_datatype - ) + ).snowpark_column for p in left_right_pairs } diff --git a/tests/integ/modin/binary/test_timedelta.py b/tests/integ/modin/binary/test_timedelta.py new file mode 100644 index 00000000000..72e3a4d75a9 --- /dev/null +++ b/tests/integ/modin/binary/test_timedelta.py @@ -0,0 +1,396 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
+# + +import datetime +import re + +import modin.pandas as pd +import numpy as np +import pandas as native_pd +import pytest + +import snowflake.snowpark.modin.plugin # noqa: F401 +from tests.integ.modin.sql_counter import sql_count_checker +from tests.integ.modin.utils import ( + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + create_test_dfs, + create_test_series, + eval_snowpark_pandas_result, +) + +PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1 = ( + native_pd.Series( + [ + None, + pd.Timestamp(year=1994, month=7, day=29), + None, + pd.Timestamp(year=1996, month=1, day=23), + pd.Timestamp(year=2000, month=1, day=23), + pd.Timestamp( + year=1700, month=1, day=1, second=22, microsecond=12345, nanosecond=56 + ), + ] + ), + native_pd.Series( + [ + None, + None, + pd.Timestamp(year=1995, month=7, day=29), + pd.Timestamp(year=1996, month=1, day=24), + pd.Timestamp(year=2024, month=7, day=8), + pd.Timestamp( + year=1700, month=1, day=2, second=49, microsecond=7, nanosecond=98 + ), + ] + ), +) + + +class TestDataFrameAndScalar: + @pytest.mark.parametrize( + "scalar", + [ + pd.NaT, + datetime.datetime(year=2024, month=8, day=14, hour=2, minute=32, second=42), + datetime.datetime(year=2023, month=3, day=14), + pd.Timestamp(year=2020, month=3, day=25), + ], + ) + @pytest.mark.parametrize("operation", ["sub", "rsub"]) + @sql_count_checker(query_count=1) + def test_timestamp_minus_timestamp(self, scalar, operation): + eval_snowpark_pandas_result( + *create_test_dfs( + [ + [datetime.datetime(year=2024, month=1, day=1), pd.NaT], + [ + datetime.datetime(year=2023, month=1, day=1), + datetime.datetime(year=2030, month=1, day=1), + ], + ] + ), + lambda df: getattr(df, operation)(scalar), + ) + + +class TestSeriesAndScalar: + @pytest.mark.parametrize( + "scalar", + [ + pd.NaT, + datetime.datetime(year=2024, month=8, day=14, hour=2, minute=32, second=42), + datetime.datetime(year=2023, month=3, day=14), + pd.Timestamp(year=2020, month=3, day=25), + ], + ) + @sql_count_checker(query_count=1) + @pytest.mark.parametrize("operation", ["sub", "rsub"]) + def test_timestamp_minus_timestamp(self, operation, scalar): + eval_snowpark_pandas_result( + *create_test_series( + [ + datetime.datetime(year=2024, month=1, day=1), + pd.NaT, + datetime.datetime(year=2023, month=1, day=1), + datetime.datetime(year=2030, month=1, day=1), + ] + ), + lambda series: getattr(series, operation)(scalar), + ) + + +class TestDataFrameAndListLikeAxis1: + @pytest.mark.parametrize("op", ["sub", "rsub"]) + @sql_count_checker(query_count=1) + def test_timestamp_minus_timestamp(self, op): + eval_snowpark_pandas_result( + *create_test_dfs( + [ + [ + pd.Timestamp(5, unit="ns"), + pd.Timestamp(700, unit="ns"), + pd.Timestamp(1399, unit="ns"), + ], + [ + pd.Timestamp(6, unit="ms"), + pd.Timestamp(800, unit="ms"), + pd.Timestamp(1499, unit="ms"), + ], + ] + ), + lambda df: getattr(df, op)( + [ + pd.Timestamp(1, unit="ns"), + pd.Timestamp(300, unit="ns"), + pd.Timestamp(57, unit="ms"), + ] + ), + ) + + +class TestSeriesAndListLike: + @sql_count_checker(query_count=1, join_count=1) + @pytest.mark.parametrize("op", ["sub", "rsub"]) + def test_timestamp_minus_timestamp(self, op): + eval_snowpark_pandas_result( + *create_test_series( + [ + pd.Timestamp(5, unit="ns"), + pd.Timestamp(700, unit="ns"), + pd.Timestamp(1399, unit="ns"), + ] + ), + lambda series: getattr(series, op)( + [ + pd.Timestamp(1, unit="ns"), + pd.Timestamp(300, unit="ns"), + pd.Timestamp(999, unit="ns"), + ] + ), + ) + + +class TestDataFrameAndListLikeAxis0: + 
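    """Test subtracting a DataFrame of timestamps and a list-like of timestamps along axis=0."""
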
    @sql_count_checker(query_count=1, join_count=1)
    @pytest.mark.parametrize("op", ["sub", "rsub"])
    def test_timestamp_minus_timestamp(self, op):
        eval_snowpark_pandas_result(
            *create_test_dfs(
                [
                    [pd.Timestamp(5, unit="ns"), pd.Timestamp(700, unit="ns")],
                    [pd.Timestamp(6, unit="ns"), pd.Timestamp(800, unit="ns")],
                    [pd.Timestamp(7, unit="ns"), pd.Timestamp(900, unit="ns")],
                ]
            ),
            lambda df: getattr(df, op)(
                [
                    pd.Timestamp(1, unit="ns"),
                    pd.Timestamp(300, unit="ns"),
                    pd.Timestamp(999, unit="ns"),
                ],
                axis=0,
            ),
        )


class TestSeriesAndSeries:
    @pytest.mark.parametrize(
        "pandas_lhs,pandas_rhs",
        [
            PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1,
            [
                x.dt.tz_localize("UTC")
                for x in PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1
            ],
            (
                PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1[0].dt.tz_localize(
                    "UTC"
                ),
                PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1[1].dt.tz_localize(
                    "Asia/Kolkata"
                ),
            ),
        ],
    )
    @pytest.mark.parametrize("op", ["sub", "rsub"])
    @sql_count_checker(query_count=1, join_count=1)
    def test_timestamp_minus_timestamp(self, pandas_lhs, pandas_rhs, op):
        """Subtract two series of timestamps to get a timedelta."""
        snow_lhs = pd.Series(pandas_lhs)
        snow_rhs = pd.Series(pandas_rhs)
        eval_snowpark_pandas_result(
            (snow_lhs, snow_rhs),
            (pandas_lhs, pandas_rhs),
            lambda inputs: getattr(inputs[0], op)(inputs[1]),
        )

    @pytest.mark.parametrize(
        "pandas_lhs,pandas_rhs",
        [
            (
                PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1[0].dt.tz_localize(
                    "UTC"
                ),
                PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1[1],
            ),
            (
                PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1[0],
                PANDAS_TIMESTAMP_SERIES_WITH_NULLS_NO_TIMEZONE_1[1].dt.tz_localize(
                    "UTC"
                ),
            ),
        ],
    )
    @pytest.mark.parametrize("op", ["sub", "rsub"])
    @sql_count_checker(query_count=0)
    def test_subtract_two_timestamps_timezones_disallowed(
        self, pandas_lhs, pandas_rhs, op
    ):
        snow_lhs = pd.Series(pandas_lhs)
        snow_rhs = pd.Series(pandas_rhs)
        # pandas is inconsistent about including a period at the end of the
        # error message, but Snowpark pandas is not.
        eval_snowpark_pandas_result(
            (snow_lhs, snow_rhs),
            (pandas_lhs, pandas_rhs),
            lambda inputs: getattr(inputs[0], op)(inputs[1]),
            expect_exception=True,
            expect_exception_match=re.escape(
                "Cannot subtract tz-naive and tz-aware datetime-like objects."
            ),
            assert_exception_equal=False,
            expect_exception_type=TypeError,
        )


class TestDataFrameAndSeriesAxis0:
    @pytest.mark.parametrize("op", ["sub", "rsub"])
    @sql_count_checker(query_count=1, join_count=1)
    def test_timestamp_dataframe_minus_timestamp_series(self, op):
        snow_df, pandas_df = create_test_dfs(
            [
                [pd.Timestamp(1, unit="ms"), pd.Timestamp(2, unit="ms")],
                [pd.Timestamp(3, unit="ms"), pd.Timestamp(4, unit="ms")],
            ]
        )
        snow_series, pandas_series = create_test_series(
            [
                pd.Timestamp(5, unit="ms"),
                pd.Timestamp(6, unit="ms"),
                pd.Timestamp(7, unit="ms"),
            ]
        )
        eval_snowpark_pandas_result(
            (snow_df, snow_series),
            (pandas_df, pandas_series),
            lambda t: getattr(t[0], op)(t[1], axis=0),
        )


class TestDataFrameAndSeriesAxis1:
    @sql_count_checker(
        # One query to materialize the series for the subtraction, and another
        # query to materialize the result.
        query_count=2
    )
    def test_timestamp_dataframe_minus_timestamp_series(self):
        """
        Test subtracting a series of timestamps from a dataframe of timestamps on axis 1.
+ pandas behavior is incorrect: https://github.com/pandas-dev/pandas/issues/59529 + """ + pandas_df = native_pd.DataFrame( + [ + [pd.Timestamp(1, unit="ms"), pd.Timestamp(2, unit="ms")], + [pd.Timestamp(3, unit="ms"), pd.Timestamp(4, unit="ms")], + ] + ) + pandas_series = native_pd.Series( + [ + pd.Timestamp(5, unit="ms"), + pd.Timestamp(6, unit="ms"), + pd.Timestamp(7, unit="ms"), + ] + ) + with pytest.raises( + TypeError, match="cannot subtract DatetimeArray from ndarray" + ): + pandas_df - pandas_series + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + pd.DataFrame(pandas_df) - pd.Series(pandas_series), + native_pd.DataFrame( + [ + [ + native_pd.Timedelta(milliseconds=-4), + native_pd.Timedelta(milliseconds=-4), + pd.NaT, + ], + [ + native_pd.Timedelta(milliseconds=-2), + native_pd.Timedelta(milliseconds=-2), + pd.NaT, + ], + ] + ), + ) + + @sql_count_checker( + # One query to materialize the series for the subtraction, and another + # query to materialize the result. + query_count=2 + ) + def test_timestamp_series_minus_timestamp_dataframe(self): + """ + Test subtracting a dataframe of timestamps from a series of timestamps. + pandas behavior is incorrect: https://github.com/pandas-dev/pandas/issues/59529 + """ + pandas_df = native_pd.DataFrame( + [ + [pd.Timestamp(1, unit="ms"), pd.Timestamp(2, unit="ms")], + [pd.Timestamp(3, unit="ms"), pd.Timestamp(4, unit="ms")], + ] + ) + pandas_series = native_pd.Series( + [ + pd.Timestamp(5, unit="ms"), + pd.Timestamp(6, unit="ms"), + pd.Timestamp(7, unit="ms"), + ] + ) + with pytest.raises( + np.core._exceptions.UFuncTypeError, + match=re.escape( + "ufunc 'subtract' cannot use operands with types dtype('