Skip to content

Commit

Permalink
SNOW-1637102: Support binary operations between timedelta and number. (
Browse files Browse the repository at this point in the history
…#2200)

Fixes SNOW-1637102

Signed-off-by: sfc-gh-mvashishtha <[email protected]>
Co-authored-by: Andong Zhan <[email protected]>
  • Loading branch information
sfc-gh-mvashishtha and sfc-gh-azhan authored Sep 4, 2024
1 parent 98a3981 commit 3240e5e
Show file tree
Hide file tree
Showing 5 changed files with 783 additions and 79 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
- support indexing with Timedelta data columns.
- support for adding or subtracting timestamps and `Timedelta`.
- support for binary arithmetic between two `Timedelta` values.
- support for binary arithmetic and comparisons between `Timedelta` values and numeric values.
- support for lazy `TimedeltaIndex`.
- support for `pd.to_timedelta`.
- support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`.
Expand Down
177 changes: 145 additions & 32 deletions src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from snowflake.snowpark.column import Column as SnowparkColumn
from snowflake.snowpark.functions import (
ceil,
col,
concat,
dateadd,
Expand All @@ -30,7 +31,10 @@
SnowparkPandasColumn,
TimedeltaType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_object_type
from snowflake.snowpark.modin.plugin._internal.type_utils import (
DataTypeGetter,
infer_object_type,
)
from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit
from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
from snowflake.snowpark.types import (
Expand All @@ -41,6 +45,7 @@
TimestampType,
_FractionalType,
_IntegralType,
_NumericType,
)

NAN_COLUMN = pandas_lit("nan").cast("float")
Expand Down Expand Up @@ -265,22 +270,61 @@ def _op_is_between_two_timedeltas_or_timedelta_and_null(
)


def _is_numeric_non_timedelta_type(datatype: DataType) -> bool:
    """
    Tell whether a datatype is a plain numeric type, i.e. numeric but not timedelta.

    Args:
        datatype: The datatype to inspect.

    Returns:
        bool: True if the datatype is numeric and not a timedelta type.
    """
    # TimedeltaType would otherwise match the numeric check, so rule it out
    # first.
    if isinstance(datatype, TimedeltaType):
        return False
    return isinstance(datatype, _NumericType)


def _op_is_between_timedelta_and_numeric(
    first_datatype: DataTypeGetter, second_datatype: DataTypeGetter
) -> bool:
    """
    Tell whether a binary operation mixes a timedelta operand with a numeric one.

    Returns True when one operand is a timedelta and the other operand is a
    numeric, non-timedelta type.

    Args:
        first_datatype: Getter for the lhs datatype (lazy, to defer metadata
            queries until a type is actually needed).
        second_datatype: Getter for the rhs datatype (lazy, same reason).

    Returns:
        bool: Whether the operation is between a timedelta and a numeric type.
    """
    lhs_type = first_datatype()
    if isinstance(lhs_type, TimedeltaType):
        return _is_numeric_non_timedelta_type(second_datatype())
    if _is_numeric_non_timedelta_type(lhs_type):
        return isinstance(second_datatype(), TimedeltaType)
    return False


def compute_binary_op_between_snowpark_columns(
op: str,
first_operand: SnowparkColumn,
first_datatype: Callable[[], DataType],
first_datatype: DataTypeGetter,
second_operand: SnowparkColumn,
second_datatype: Callable[[], DataType],
second_datatype: DataTypeGetter,
) -> SnowparkPandasColumn:
"""
Compute pandas binary operation for two SnowparkColumns
Args:
op: pandas operation
first_operand: SnowparkColumn for lhs
first_datatype: Callable for Snowpark Datatype for lhs, this is lazy so we can avoid pulling the value if
it is not needed.
first_datatype: Callable for Snowpark Datatype for lhs
second_operand: SnowparkColumn for rhs
second_datatype: Callable for Snowpark DateType for rhs, this is lazy so we can avoid pulling the value if
second_datatype: Callable for Snowpark DataType for rhs
Returns:
Expand Down Expand Up @@ -383,34 +427,104 @@ def compute_binary_op_between_snowpark_columns(
raise np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined]
np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]"))
)
elif _op_is_between_two_timedeltas_or_timedelta_and_null(
elif op in (
"eq",
"ne",
"gt",
"ge",
"lt",
"le",
) and _op_is_between_two_timedeltas_or_timedelta_and_null(
first_datatype(), second_datatype()
) and op in ("eq", "ne", "gt", "ge", "lt", "le", "truediv"):
):
# These operations, when done between timedeltas, work without any
# extra handling in `snowpark_pandas_type` or `binary_op_result_column`.
# They produce outputs that are not timedeltas (e.g. numbers for floordiv
# and truediv, and bools for the comparisons).
pass
elif op == "mul" and (
_op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
):
binary_op_result_column = first_operand * second_operand
snowpark_pandas_type = TimedeltaType()
# For `eq` and `ne`, note that Snowflake will consider 1 equal to
# Timedelta(1) because those two have the same representation in Snowflake,
# so we have to compare types in the client.
elif op == "eq" and (
_op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
):
binary_op_result_column = pandas_lit(False)
elif op == "ne" and _op_is_between_timedelta_and_numeric(
first_datatype, second_datatype
):
binary_op_result_column = pandas_lit(True)
elif (
# equal_null and floordiv for timedelta also work without special
# handling, but we need to exclude them from the above case so we catch
# them in an `elif` clause further down.
op not in ("equal_null", "floordiv")
and (
(
isinstance(first_datatype(), TimedeltaType)
and not isinstance(second_datatype(), TimedeltaType)
)
or (
not isinstance(first_datatype(), TimedeltaType)
and isinstance(second_datatype(), TimedeltaType)
op in ("truediv", "floordiv")
and isinstance(first_datatype(), TimedeltaType)
and _is_numeric_non_timedelta_type(second_datatype())
):
binary_op_result_column = floor(first_operand / second_operand)
snowpark_pandas_type = TimedeltaType()
elif (
op == "mod"
and isinstance(first_datatype(), TimedeltaType)
and _is_numeric_non_timedelta_type(second_datatype())
):
binary_op_result_column = ceil(
compute_modulo_between_snowpark_columns(
first_operand, first_datatype(), second_operand, second_datatype()
)
)
snowpark_pandas_type = TimedeltaType()
elif op in ("add", "sub") and (
(
isinstance(first_datatype(), TimedeltaType)
and _is_numeric_non_timedelta_type(second_datatype())
)
or (
_is_numeric_non_timedelta_type(first_datatype())
and isinstance(second_datatype(), TimedeltaType)
)
):
raise TypeError(
"Snowpark pandas does not support addition or subtraction between timedelta values and numeric values."
)
elif op in ("truediv", "floordiv", "mod") and (
_is_numeric_non_timedelta_type(first_datatype())
and isinstance(second_datatype(), TimedeltaType)
):
# We don't support these cases yet.
# TODO(SNOW-1637102): Support this case.
raise TypeError(
"Snowpark pandas does not support dividing numeric values by timedelta values with div (/), mod (%), or floordiv (//)."
)
elif op in (
"add",
"sub",
"truediv",
"floordiv",
"mod",
"gt",
"ge",
"lt",
"le",
"ne",
"eq",
) and (
(
isinstance(first_datatype(), TimedeltaType)
and isinstance(second_datatype(), StringType)
)
or (
isinstance(second_datatype(), TimedeltaType)
and isinstance(first_datatype(), StringType)
)
):
# TODO(SNOW-1646604): Support these cases.
ErrorMessage.not_implemented(
f"Snowpark pandas does not yet support the binary operation {op} with a Timedelta column and a non-Timedelta column."
f"Snowpark pandas does not yet support the operation {op} between timedelta and string"
)
elif op in ("gt", "ge", "lt", "le", "pow", "__or__", "__and__") and (
_op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
):
raise TypeError(
f"Snowpark pandas does not support binary operation {op} between timedelta and a non-timedelta type."
)
elif op == "floordiv":
binary_op_result_column = floor(first_operand / second_operand)
Expand Down Expand Up @@ -527,16 +641,15 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool:
def compute_binary_op_between_snowpark_column_and_scalar(
op: str,
first_operand: SnowparkColumn,
datatype: Callable[[], DataType],
datatype: DataTypeGetter,
second_operand: Scalar,
) -> SnowparkPandasColumn:
"""
Compute the binary operation between a Snowpark column and a scalar.
Args:
op: the name of binary operation
first_operand: The SnowparkColumn for lhs
datatype: Callable for Snowpark data type, this is lazy so we can avoid pulling the value if
it is not needed.
datatype: Callable for Snowpark data type
second_operand: Scalar value
Returns:
Expand All @@ -555,15 +668,15 @@ def compute_binary_op_between_scalar_and_snowpark_column(
op: str,
first_operand: Scalar,
second_operand: SnowparkColumn,
datatype: Callable[[], DataType],
datatype: DataTypeGetter,
) -> SnowparkPandasColumn:
"""
Compute the binary operation between a scalar and a Snowpark column.
Args:
op: the name of binary operation
first_operand: Scalar value
second_operand: The SnowparkColumn for rhs
datatype: Callable for Snowpark data type, this is lazy so we can avoid pulling the value if
datatype: Callable for Snowpark data type
Returns:
Expand All @@ -581,9 +694,9 @@ def first_datatype() -> DataType:
def compute_binary_op_with_fill_value(
op: str,
lhs: SnowparkColumn,
lhs_datatype: Callable[[], DataType],
lhs_datatype: DataTypeGetter,
rhs: SnowparkColumn,
rhs_datatype: Callable[[], DataType],
rhs_datatype: DataTypeGetter,
fill_value: Scalar,
) -> SnowparkPandasColumn:
"""
Expand Down
7 changes: 6 additions & 1 deletion src/snowflake/snowpark/modin/plugin/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
from functools import reduce
from typing import Any, Union
from typing import Any, Callable, Union

import numpy as np
import pandas as native_pd
Expand Down Expand Up @@ -77,6 +77,11 @@
_NumericType,
)

# This type is for a function that returns a DataType. By using it to lazily
# get a DataType, we can sometimes defer metadata queries until we need to
# check a type.
DataTypeGetter = Callable[[], DataType]

# The order of this mapping is important because the first match in either
# direction is used by TypeMapper.to_pandas() and TypeMapper.to_snowflake()
NUMPY_SNOWFLAKE_TYPE_PAIRS: list[tuple[Union[type, str], DataType]] = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@
transpose_empty_df,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import (
DataTypeGetter,
TypeMapper,
column_astype,
infer_object_type,
Expand Down Expand Up @@ -14177,11 +14178,10 @@ def _binary_op_between_dataframe_and_series_along_axis_0(
)
)

# Lazify type map here for calling compute_binary_op_between_snowpark_columns,
# this enables the optimization to pull datatypes only on-demand if needed.
# Lazify type map here for calling compute_binary_op_between_snowpark_columns.
def create_lazy_type_functions(
identifiers: list[str],
) -> list[Callable[[], DataType]]:
) -> list[DataTypeGetter]:
"""
create functions that return datatype on demand for an identifier.
Args:
Expand Down
Loading

0 comments on commit 3240e5e

Please sign in to comment.