Skip to content

Commit

Permalink
SNOW-1320449: Support subtracting two timestamps to get a timedelta. (#2113)
Browse files Browse the repository at this point in the history

Signed-off-by: sfc-gh-mvashishtha <[email protected]>
Co-authored-by: Naren Krishna <[email protected]>
  • Loading branch information
sfc-gh-mvashishtha and sfc-gh-nkrishna authored Aug 21, 2024
1 parent 55da0a3 commit d4b4638
Show file tree
Hide file tree
Showing 6 changed files with 657 additions and 87 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
- Added support for `Index.__repr__`.
- Added support for `DatetimeIndex.month_name` and `DatetimeIndex.day_name`.
- Added support for `Series.dt.weekday`, `Series.dt.time`, and `DatetimeIndex.time`.
- Added support for subtracting two timestamps to get a Timedelta.

#### Bug Fixes

Expand Down
135 changes: 116 additions & 19 deletions src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,39 @@
import functools
from collections.abc import Hashable
from dataclasses import dataclass
from enum import Enum, auto

import pandas as native_pd
from pandas._typing import Callable, Scalar

from snowflake.snowpark.column import Column as SnowparkColumn
from snowflake.snowpark.functions import col, concat, floor, iff, repeat, when
from snowflake.snowpark.functions import (
col,
concat,
datediff,
floor,
iff,
is_null,
repeat,
when,
)
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.join_utils import (
JoinOrAlignInternalFrameResult,
)
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
SnowparkPandasColumn,
TimedeltaType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_object_type
from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit
from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
from snowflake.snowpark.types import (
DataType,
NullType,
StringType,
TimestampTimeZone,
TimestampType,
_FractionalType,
_IntegralType,
)
Expand Down Expand Up @@ -172,13 +189,64 @@ def is_binary_op_supported(op: str) -> bool:
return op in SUPPORTED_BINARY_OPERATIONS


class SubtractionType(Enum):
    """Distinguishes forward subtraction (``sub``) from reversed subtraction (``rsub``)."""

    # first_operand - second_operand, e.g. pd.Series(2).sub(pd.Series(1)) gives pd.Series(1)
    SUB = auto()
    # second_operand - first_operand, e.g. pd.Series(2).rsub(pd.Series(1)) gives pd.Series(-1)
    RSUB = auto()


def _compute_subtraction_between_snowpark_timestamp_columns(
    first_operand: SnowparkColumn,
    first_datatype: TimestampType,
    second_operand: SnowparkColumn,
    second_datatype: TimestampType,
    subtraction_type: SubtractionType,
) -> SnowparkPandasColumn:
    """
    Compute subtraction between two Snowpark timestamp columns.

    Args:
        first_operand: SnowparkColumn for lhs
        first_datatype: Snowpark timestamp datatype for lhs (its ``tz`` field is inspected)
        second_operand: SnowparkColumn for rhs
        second_datatype: Snowpark timestamp datatype for rhs (its ``tz`` field is inspected)
        subtraction_type: Type of subtraction, SUB or RSUB.

    Returns:
        SnowparkPandasColumn holding the nanosecond difference with TimedeltaType.

    Raises:
        TypeError: if exactly one operand is timezone-aware (NTZ vs. TZ in
            either order), matching pandas, which refuses to subtract tz-naive
            and tz-aware datetime-like values.
    """
    if (
        first_datatype.tz is TimestampTimeZone.NTZ
        and second_datatype.tz is TimestampTimeZone.TZ
    ) or (
        first_datatype.tz is TimestampTimeZone.TZ
        and second_datatype.tz is TimestampTimeZone.NTZ
    ):
        raise TypeError("Cannot subtract tz-naive and tz-aware datetime-like objects.")
    # datediff(part, start, end) computes end - start, so for SUB
    # (first - second) the argument order is (second, first), and the
    # reverse for RSUB.
    datediff_args = (
        (first_operand, second_operand)
        if subtraction_type is SubtractionType.RSUB
        else (second_operand, first_operand)
    )
    return SnowparkPandasColumn(
        # Explicitly produce NaT when either side is NULL so the result
        # surfaces as a pandas missing value rather than a bare SQL NULL.
        iff(
            is_null(first_operand) | is_null(second_operand),
            pandas_lit(native_pd.NaT),
            datediff("ns", *datediff_args),
        ),
        TimedeltaType(),
    )


def compute_binary_op_between_snowpark_columns(
op: str,
first_operand: SnowparkColumn,
first_datatype: Callable[[], DataType],
second_operand: SnowparkColumn,
second_datatype: Callable[[], DataType],
) -> SnowparkColumn:
) -> SnowparkPandasColumn:
"""
Compute pandas binary operation for two SnowparkColumns
Args:
Expand All @@ -191,10 +259,9 @@ def compute_binary_op_between_snowpark_columns(
it is not needed.
Returns:
SnowparkColumn expr for translated pandas operation
SnowparkPandasColumn for translated pandas operation
"""

binary_op_result_column = None
binary_op_result_column, snowpark_pandas_result_type = None, None

# some operators and the data types have to be handled specially to align with pandas
# However, it is difficult to fail early if the arithmetic operator is not compatible
Expand Down Expand Up @@ -272,14 +339,48 @@ def compute_binary_op_between_snowpark_columns(
binary_op_result_column = pandas_lit(False)
else:
binary_op_result_column = first_operand.equal_null(second_operand)

elif (
op in ("rsub", "sub")
and isinstance(first_datatype(), TimestampType)
and isinstance(second_datatype(), NullType)
):
# Timestamp - NULL or NULL - Timestamp raises SQL compilation error,
# but it's valid in pandas and returns NULL.
snowpark_pandas_result_type = NullType()
binary_op_result_column = pandas_lit(None)
elif (
op in ("rsub", "sub")
and isinstance(first_datatype(), NullType)
and isinstance(second_datatype(), TimestampType)
):
# Timestamp - NULL or NULL - Timestamp raises SQL compilation error,
# but it's valid in pandas and returns NULL.
snowpark_pandas_result_type = NullType()
binary_op_result_column = pandas_lit(None)
elif (
op in ("sub", "rsub")
and isinstance(first_datatype(), TimestampType)
and isinstance(second_datatype(), TimestampType)
):
return _compute_subtraction_between_snowpark_timestamp_columns(
first_operand=first_operand,
first_datatype=first_datatype(),
second_operand=second_operand,
second_datatype=second_datatype(),
subtraction_type=SubtractionType.SUB
if op == "sub"
else SubtractionType.RSUB,
)
# If there is no special binary_op_result_column result, it means the operator and
# the data type of the column don't need special handling. Then we get the overloaded
# operator from Snowpark Column class, e.g., __add__ to perform binary operations.
if binary_op_result_column is None:
binary_op_result_column = getattr(first_operand, f"__{op}__")(second_operand)

return binary_op_result_column
return SnowparkPandasColumn(
snowpark_column=binary_op_result_column,
snowpark_pandas_type=snowpark_pandas_result_type,
)


def are_equal_types(type1: DataType, type2: DataType) -> bool:
Expand Down Expand Up @@ -307,7 +408,7 @@ def compute_binary_op_between_snowpark_column_and_scalar(
first_operand: SnowparkColumn,
datatype: Callable[[], DataType],
second_operand: Scalar,
) -> SnowparkColumn:
) -> SnowparkPandasColumn:
"""
Compute the binary operation between a Snowpark column and a scalar.
Args:
Expand All @@ -318,16 +419,14 @@ def compute_binary_op_between_snowpark_column_and_scalar(
second_operand: Scalar value
Returns:
The result as a Snowpark column
SnowparkPandasColumn for translated pandas operation
"""

def second_datatype() -> DataType:
return infer_object_type(second_operand)

second_operand = pandas_lit(second_operand)

return compute_binary_op_between_snowpark_columns(
op, first_operand, datatype, second_operand, second_datatype
op, first_operand, datatype, pandas_lit(second_operand), second_datatype
)


Expand All @@ -336,7 +435,7 @@ def compute_binary_op_between_scalar_and_snowpark_column(
first_operand: Scalar,
second_operand: SnowparkColumn,
datatype: Callable[[], DataType],
) -> SnowparkColumn:
) -> SnowparkPandasColumn:
"""
Compute the binary operation between a scalar and a Snowpark column.
Args:
Expand All @@ -347,16 +446,14 @@ def compute_binary_op_between_scalar_and_snowpark_column(
it is not needed.
Returns:
The result as a Snowpark column
SnowparkPandasColumn for translated pandas operation
"""

def first_datatype() -> DataType:
return infer_object_type(first_operand)

first_operand = pandas_lit(first_operand)

return compute_binary_op_between_snowpark_columns(
op, first_operand, first_datatype, second_operand, datatype
op, pandas_lit(first_operand), first_datatype, second_operand, datatype
)


Expand All @@ -367,7 +464,7 @@ def compute_binary_op_with_fill_value(
rhs: SnowparkColumn,
rhs_datatype: Callable[[], DataType],
fill_value: Scalar,
) -> SnowparkColumn:
) -> SnowparkPandasColumn:
"""
Helper method for performing binary operations.
1. Fills NaN/None values in the lhs and rhs with the given fill_value.
Expand All @@ -392,7 +489,7 @@ def compute_binary_op_with_fill_value(
successful DataFrame alignment, with this value before computation.
Returns:
SnowparkColumn expression for translated pandas operation
SnowparkPandasColumn for translated pandas operation
"""
lhs_cond, rhs_cond = lhs, rhs
if fill_value is not None:
Expand Down
39 changes: 31 additions & 8 deletions src/snowflake/snowpark/modin/plugin/_internal/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,10 @@ def persist_to_temporary_table(self) -> "InternalFrame":
)

def append_column(
self, pandas_label: Hashable, value: SnowparkColumn
self,
pandas_label: Hashable,
value: SnowparkColumn,
value_type: Optional[SnowparkPandasType] = None,
) -> "InternalFrame":
"""
Append a column to this frame. The column is added at the end. For a frame with multiindex column, it
Expand All @@ -935,6 +938,7 @@ def append_column(
Args:
pandas_label: pandas label for column to be inserted.
value: SnowparkColumn.
value_type: The optional SnowparkPandasType for the new column.
Returns:
A new InternalFrame with new column.
Expand Down Expand Up @@ -975,7 +979,8 @@ def append_column(
data_column_pandas_index_names=self.data_column_pandas_index_names,
index_column_pandas_labels=self.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers,
data_column_types=self.cached_data_column_snowpark_pandas_types + [None],
data_column_types=self.cached_data_column_snowpark_pandas_types
+ [value_type],
index_column_types=self.cached_index_column_snowpark_pandas_types,
)

Expand Down Expand Up @@ -1109,6 +1114,9 @@ def get_updated_identifiers(identifiers: list[str]) -> list[str]:
def update_snowflake_quoted_identifiers_with_expressions(
self,
quoted_identifier_to_column_map: dict[str, SnowparkColumn],
data_column_snowpark_pandas_types: Optional[
list[Optional[SnowparkPandasType]]
] = None,
) -> UpdatedInternalFrameResult:
"""
Points Snowflake quoted identifiers to column expression given by `quoted_identifier_to_column_map`.
Expand All @@ -1134,6 +1142,8 @@ def update_snowflake_quoted_identifiers_with_expressions(
existing snowflake quoted identifiers to new Snowpark columns.
As keys of a dictionary, all snowflake column identifiers are unique here and
must be index columns and data columns in the original internal frame.
data_column_snowpark_pandas_types: The optional Snowpark pandas types for the new
expressions, in the order of the keys of quoted_identifier_to_column_map.
Returns:
UpdatedInternalFrameResult: A tuple contaning the new InternalFrame with updated column references, and a mapping
Expand Down Expand Up @@ -1168,10 +1178,16 @@ def update_snowflake_quoted_identifiers_with_expressions(

existing_id_to_new_id_mapping = {}
new_columns = []
for (
existing_identifier,
column_expression,
) in quoted_identifier_to_column_map.items():
new_type_mapping = dict(
self.snowflake_quoted_identifier_to_snowpark_pandas_type
)
if data_column_snowpark_pandas_types is None:
data_column_snowpark_pandas_types = [None] * len(
quoted_identifier_to_column_map
)
for ((existing_identifier, column_expression,), data_column_type) in zip(
quoted_identifier_to_column_map.items(), data_column_snowpark_pandas_types
):
new_identifier = (
self.ordered_dataframe.generate_snowflake_quoted_identifiers(
pandas_labels=[
Expand All @@ -1183,6 +1199,7 @@ def update_snowflake_quoted_identifiers_with_expressions(
)
existing_id_to_new_id_mapping[existing_identifier] = new_identifier
new_columns.append(column_expression)
new_type_mapping[new_identifier] = data_column_type
new_ordered_dataframe = append_columns(
self.ordered_dataframe,
list(existing_id_to_new_id_mapping.values()),
Expand All @@ -1206,10 +1223,16 @@ def update_snowflake_quoted_identifiers_with_expressions(
data_column_pandas_labels=self.data_column_pandas_labels,
data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers,
data_column_pandas_index_names=self.data_column_pandas_index_names,
data_column_types=self.cached_data_column_snowpark_pandas_types,
index_column_pandas_labels=self.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=new_index_column_snowflake_quoted_identifiers,
index_column_types=self.cached_index_column_snowpark_pandas_types,
data_column_types=[
new_type_mapping[k]
for k in new_data_column_snowflake_quoted_identifiers
],
index_column_types=[
new_type_mapping[k]
for k in new_index_column_snowflake_quoted_identifiers
],
),
existing_id_to_new_id_mapping,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import inspect
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, NamedTuple, Optional, Union

import numpy as np
import pandas as native_pd

from snowflake.snowpark.column import Column
from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
from snowflake.snowpark.types import LongType

Expand Down Expand Up @@ -99,6 +100,15 @@ def get_snowpark_pandas_type_for_pandas_type(
return _pandas_type_to_snowpark_pandas_type.get(pandas_type, None)


class SnowparkPandasColumn(NamedTuple):
    """A Snowpark Column paired with its optional SnowparkPandasType."""

    # The Snowpark Column expression.
    snowpark_column: Column
    # The SnowparkPandasType for the column when it has one; None when the
    # column's type is an ordinary Snowpark type.
    snowpark_pandas_type: Optional[SnowparkPandasType]


class TimedeltaType(SnowparkPandasType, LongType):
"""
Timedelta represents the difference between two times.
Expand Down
Loading

0 comments on commit d4b4638

Please sign in to comment.