Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1637102: Support binary operations between timedelta and number. #2200

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
- support indexing with Timedelta data columns.
- support for adding or subtracting timestamps and `Timedelta`.
- support for binary arithmetic between two `Timedelta` values.
- support for binary arithmetic and comparisons between `Timedelta` values and numeric values.
- support for lazy `TimedeltaIndex`.
- support for `pd.to_timedelta`.
- support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`.
Expand Down
177 changes: 145 additions & 32 deletions src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from snowflake.snowpark.column import Column as SnowparkColumn
from snowflake.snowpark.functions import (
ceil,
col,
concat,
dateadd,
Expand All @@ -30,7 +31,10 @@
SnowparkPandasColumn,
TimedeltaType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_object_type
from snowflake.snowpark.modin.plugin._internal.type_utils import (
DataTypeGetter,
infer_object_type,
)
from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit
from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
from snowflake.snowpark.types import (
Expand All @@ -41,6 +45,7 @@
TimestampType,
_FractionalType,
_IntegralType,
_NumericType,
)

NAN_COLUMN = pandas_lit("nan").cast("float")
Expand Down Expand Up @@ -265,22 +270,61 @@ def _op_is_between_two_timedeltas_or_timedelta_and_null(
)


def _is_numeric_non_timedelta_type(datatype: DataType) -> bool:
"""
Whether the datatype is numeric, but not a timedelta type.

Args:
datatype: The datatype

Returns:
bool: Whether the datatype is numeric, but not a timedelta type.
"""
return isinstance(datatype, _NumericType) and not isinstance(
datatype, TimedeltaType
)


def _op_is_between_timedelta_and_numeric(
first_datatype: DataTypeGetter, second_datatype: DataTypeGetter
) -> bool:
"""
Whether the binary operation is between a timedelta and a numeric type.

Returns true if either operand is a timedelta and the other operand is a
non-timedelta numeric.

Args:
First datatype: Getter for first datatype.
Second datatype: Getter for second datatype.

Returns:
bool: Whether the binary operation is between a timedelta and a numeric type.
"""
return (
isinstance(first_datatype(), TimedeltaType)
and _is_numeric_non_timedelta_type(second_datatype())
) or (
_is_numeric_non_timedelta_type(first_datatype())
and isinstance(second_datatype(), TimedeltaType)
)


def compute_binary_op_between_snowpark_columns(
op: str,
first_operand: SnowparkColumn,
first_datatype: Callable[[], DataType],
first_datatype: DataTypeGetter,
second_operand: SnowparkColumn,
second_datatype: Callable[[], DataType],
second_datatype: DataTypeGetter,
) -> SnowparkPandasColumn:
"""
Compute pandas binary operation for two SnowparkColumns
Args:
op: pandas operation
first_operand: SnowparkColumn for lhs
first_datatype: Callable for Snowpark Datatype for lhs, this is lazy so we can avoid pulling the value if
it is not needed.
first_datatype: Callable for Snowpark Datatype for lhs
second_operand: SnowparkColumn for rhs
second_datatype: Callable for Snowpark DateType for rhs, this is lazy so we can avoid pulling the value if
second_datatype: Callable for Snowpark DateType for rhs
it is not needed.

Returns:
Expand Down Expand Up @@ -383,34 +427,104 @@ def compute_binary_op_between_snowpark_columns(
raise np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined]
np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]"))
)
elif _op_is_between_two_timedeltas_or_timedelta_and_null(
elif op in (
"eq",
"ne",
"gt",
"ge",
"lt",
"le",
) and _op_is_between_two_timedeltas_or_timedelta_and_null(
first_datatype(), second_datatype()
) and op in ("eq", "ne", "gt", "ge", "lt", "le", "truediv"):
):
# These operations, when done between timedeltas, work without any
# extra handling in `snowpark_pandas_type` or `binary_op_result_column`.
# They produce outputs that are not timedeltas (e.g. numbers for floordiv
# and truediv, and bools for the comparisons).
pass
elif op == "mul" and (
sfc-gh-mvashishtha marked this conversation as resolved.
Show resolved Hide resolved
_op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
):
binary_op_result_column = first_operand * second_operand
snowpark_pandas_type = TimedeltaType()
# For `eq` and `ne`, note that Snowflake will consider 1 equal to
# Timedelta(1) because those two have the same representation in Snowflake,
# so we have to compare types in the client.
elif op == "eq" and (
_op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
):
binary_op_result_column = pandas_lit(False)
elif op == "ne" and _op_is_between_timedelta_and_numeric(
first_datatype, second_datatype
):
binary_op_result_column = pandas_lit(True)
elif (
# equal_null and floordiv for timedelta also work without special
# handling, but we need to exclude them from the above case so we catch
# them in an `elif` clause further down.
op not in ("equal_null", "floordiv")
and (
(
isinstance(first_datatype(), TimedeltaType)
and not isinstance(second_datatype(), TimedeltaType)
)
or (
not isinstance(first_datatype(), TimedeltaType)
and isinstance(second_datatype(), TimedeltaType)
op in ("truediv", "floordiv")
and isinstance(first_datatype(), TimedeltaType)
and _is_numeric_non_timedelta_type(second_datatype())
):
binary_op_result_column = floor(first_operand / second_operand)
snowpark_pandas_type = TimedeltaType()
elif (
op == "mod"
and isinstance(first_datatype(), TimedeltaType)
and _is_numeric_non_timedelta_type(second_datatype())
):
binary_op_result_column = ceil(
compute_modulo_between_snowpark_columns(
first_operand, first_datatype(), second_operand, second_datatype()
)
)
snowpark_pandas_type = TimedeltaType()
elif op in ("add", "sub") and (
(
isinstance(first_datatype(), TimedeltaType)
and _is_numeric_non_timedelta_type(second_datatype())
)
or (
_is_numeric_non_timedelta_type(first_datatype())
and isinstance(second_datatype(), TimedeltaType)
)
):
raise TypeError(
"Snowpark pandas does not support addition or subtraction between timedelta values and numeric values."
)
elif op in ("truediv", "floordiv", "mod") and (
_is_numeric_non_timedelta_type(first_datatype())
and isinstance(second_datatype(), TimedeltaType)
):
# We don't support these cases yet.
# TODO(SNOW-1637102): Support this case.
raise TypeError(
"Snowpark pandas does not support dividing numeric values by timedelta values with div (/), mod (%), or floordiv (//)."
)
elif op in (
"add",
"sub",
"truediv",
"floordiv",
"mod",
"gt",
"ge",
"lt",
"le",
"ne",
"eq",
) and (
(
isinstance(first_datatype(), TimedeltaType)
sfc-gh-mvashishtha marked this conversation as resolved.
Show resolved Hide resolved
and isinstance(second_datatype(), StringType)
)
or (
isinstance(second_datatype(), TimedeltaType)
and isinstance(first_datatype(), StringType)
)
):
# TODO(SNOW-1646604): Support these cases.
ErrorMessage.not_implemented(
f"Snowpark pandas does not yet support the binary operation {op} with a Timedelta column and a non-Timedelta column."
f"Snowpark pandas does not yet support the operation {op} between timedelta and string"
)
elif op in ("gt", "ge", "lt", "le", "pow", "__or__", "__and__") and (
_op_is_between_timedelta_and_numeric(first_datatype, second_datatype)
):
raise TypeError(
f"Snowpark pandas does not support binary operation {op} between timedelta and a non-timedelta type."
)
elif op == "floordiv":
binary_op_result_column = floor(first_operand / second_operand)
Expand Down Expand Up @@ -527,16 +641,15 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool:
def compute_binary_op_between_snowpark_column_and_scalar(
op: str,
first_operand: SnowparkColumn,
datatype: Callable[[], DataType],
datatype: DataTypeGetter,
second_operand: Scalar,
) -> SnowparkPandasColumn:
"""
Compute the binary operation between a Snowpark column and a scalar.
Args:
op: the name of binary operation
first_operand: The SnowparkColumn for lhs
datatype: Callable for Snowpark data type, this is lazy so we can avoid pulling the value if
it is not needed.
datatype: Callable for Snowpark data type
second_operand: Scalar value

Returns:
Expand All @@ -555,15 +668,15 @@ def compute_binary_op_between_scalar_and_snowpark_column(
op: str,
first_operand: Scalar,
second_operand: SnowparkColumn,
datatype: Callable[[], DataType],
datatype: DataTypeGetter,
) -> SnowparkPandasColumn:
"""
Compute the binary operation between a scalar and a Snowpark column.
Args:
op: the name of binary operation
first_operand: Scalar value
second_operand: The SnowparkColumn for rhs
datatype: Callable for Snowpark data type, this is lazy so we can avoid pulling the value if
datatype: Callable for Snowpark data type
it is not needed.

Returns:
Expand All @@ -581,9 +694,9 @@ def first_datatype() -> DataType:
def compute_binary_op_with_fill_value(
op: str,
lhs: SnowparkColumn,
lhs_datatype: Callable[[], DataType],
lhs_datatype: DataTypeGetter,
rhs: SnowparkColumn,
rhs_datatype: Callable[[], DataType],
rhs_datatype: DataTypeGetter,
fill_value: Scalar,
) -> SnowparkPandasColumn:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
from functools import reduce
from typing import Any, Union
from typing import Any, Callable, Union

import numpy as np
import pandas as native_pd
Expand Down Expand Up @@ -77,6 +77,11 @@
_NumericType,
)

# This type is for a function that returns a DataType. By using it to lazily
# get a DataType, we can sometimes defer metadata queries until we need to
# check a type.
DataTypeGetter = Callable[[], DataType]

# The order of this mapping is important because the first match in either
# direction is used by TypeMapper.to_pandas() and TypeMapper.to_snowflake()
NUMPY_SNOWFLAKE_TYPE_PAIRS: list[tuple[Union[type, str], DataType]] = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@
transpose_empty_df,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import (
DataTypeGetter,
TypeMapper,
column_astype,
infer_object_type,
Expand Down Expand Up @@ -14117,11 +14118,10 @@ def _binary_op_between_dataframe_and_series_along_axis_0(
)
)

# Lazify type map here for calling compute_binary_op_between_snowpark_columns,
# this enables the optimization to pull datatypes only on-demand if needed.
# Lazify type map here for calling compute_binary_op_between_snowpark_columns.
def create_lazy_type_functions(
identifiers: list[str],
) -> list[Callable[[], DataType]]:
) -> list[DataTypeGetter]:
"""
create functions that return datatype on demand for an identifier.
Args:
Expand Down
Loading
Loading