SNOW-1641729: Support diff() for timestamp columns.
Signed-off-by: sfc-gh-mvashishtha <[email protected]>
sfc-gh-mvashishtha committed Aug 30, 2024
1 parent 4bcd987 commit a5d06d9
Showing 5 changed files with 75 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -63,6 +63,7 @@
- support for lazy `TimedeltaIndex`.
- support for `pd.to_timedelta`.
- support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`.
- support for `diff()` with timestamp columns on `axis=0` and `axis=1`.
- Added support for index's arithmetic and comparison operators.
- Added support for `Series.dt.round`.
- Added documentation pages for `DatetimeIndex`.
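The new CHANGELOG entry above covers `diff()` over timestamp columns. A minimal native-pandas sketch of the semantics Snowpark pandas now matches (the data here is illustrative, not taken from the changed tests): differencing datetime64 columns yields timedelta64 results along either axis.

```python
import pandas as pd

# Two datetime64[ns] columns. diff() subtracts the previous row (axis=0)
# or the previous column (axis=1) and yields timedelta64[ns] values,
# with NaT in the first row or column.
df = pd.DataFrame(
    {
        "start": pd.to_datetime(["2024-01-01", "2024-01-03", "2024-01-10"]),
        "end": pd.to_datetime(["2024-01-02", "2024-01-05", "2024-01-11"]),
    }
)

print(df.diff())        # row-wise differences within each column
print(df.diff(axis=1))  # per-row difference: "end" minus "start"
```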
@@ -10189,7 +10189,7 @@ def _make_discrete_difference_expression(
periods: int,
column_position: int,
axis: int,
) -> SnowparkColumn:
) -> SnowparkPandasColumn:
"""
Helper function to generate Columns for discrete difference.

@@ -10207,9 +10207,10 @@

Returns
-------
SnowparkColumn
An expression to generate the discrete difference along the specified axis, with the
specified period, for the column specified by `column_position`.
SnowparkPandasColumn
A column representing the discrete difference along the specified
axis, with the specified period, for the column specified by
`column_position`.
"""
# If periods is 0, we are doing a subtraction with self (or XOR in case of bool
# dtype). In this case, even if axis is 0, we prefer to use the col-wise code,
@@ -10239,15 +10240,25 @@
self._modin_frame.ordering_column_snowflake_quoted_identifiers
)
)
return (col1 | col2) & (not_(col1 & col2))
return SnowparkPandasColumn(
snowpark_column=(col1 | col2) & (not_(col1 & col2)),
snowpark_pandas_type=None,
)
else:
return col(snowflake_quoted_identifier) - func_for_other(
snowflake_quoted_identifier, offset=abs(periods)
).over(
Window.order_by(
self._modin_frame.ordering_column_snowflake_quoted_identifiers
)
return compute_binary_op_between_snowpark_columns(
"sub",
col(snowflake_quoted_identifier),
lambda: column_datatype,
func_for_other(
snowflake_quoted_identifier, offset=abs(periods)
).over(
Window.order_by(
self._modin_frame.ordering_column_snowflake_quoted_identifiers
)
),
lambda: column_datatype,
)

else:
# periods is the number of columns to *go back*.
periods *= -1
@@ -10258,7 +10269,9 @@
if other_column_position < 0 or other_column_position >= len(
self._modin_frame.data_column_snowflake_quoted_identifiers
):
return pandas_lit(np.nan)
return SnowparkPandasColumn(
snowpark_column=pandas_lit(np.nan), snowpark_pandas_type=None
)
# In this case, we are at a column that does have a match, so we must do dtype checking
# and then generate the expression.
else:
Expand All @@ -10285,13 +10298,21 @@ def _make_discrete_difference_expression(
if isinstance(col1_dtype, BooleanType) and isinstance(
col2_dtype, BooleanType
):
return (col1 | col2) & (not_(col1 & col2))
return SnowparkPandasColumn(
(col1 | col2) & (not_(col1 & col2)), snowpark_pandas_type=None
)
else:
if isinstance(col1_dtype, BooleanType):
col1 = cast(col1, IntegerType())
if isinstance(col2_dtype, BooleanType):
col2 = cast(col2, IntegerType())
return col1 - col2
return compute_binary_op_between_snowpark_columns(
"sub",
col1,
lambda: col1_dtype,
col2,
lambda: col2_dtype,
)

def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler":
"""
@@ -10312,8 +10333,12 @@ def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler":
}
return SnowflakeQueryCompiler(
self._modin_frame.update_snowflake_quoted_identifiers_with_expressions(
diff_label_to_value_map,
self._modin_frame.cached_data_column_snowpark_pandas_types,
quoted_identifier_to_column_map={
k: v.snowpark_column for k, v in diff_label_to_value_map.items()
},
data_column_snowpark_pandas_types=[
a.snowpark_pandas_type for a in diff_label_to_value_map.values()
],
).frame
)

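The boolean branches of `_make_discrete_difference_expression` above build `(col1 | col2) & (not_(col1 & col2))` instead of a subtraction. A small plain-Python sketch (not part of the commit) of why that expression is an XOR, which is the behavior pandas uses for `diff()` on bool data:

```python
from itertools import product

import pandas as pd

# (a | b) & ~(a & b) is logically equivalent to a ^ b; this is the same
# shape of expression the query compiler builds with Snowpark's col/not_.
for a, b in product([False, True], repeat=2):
    assert ((a or b) and not (a and b)) == (a != b)

# pandas applies XOR semantics when diffing bool data:
print(pd.Series([True, False, True, True]).diff())  # NaN, True, True, False
```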
17 changes: 17 additions & 0 deletions tests/integ/modin/frame/test_diff.py
@@ -154,6 +154,23 @@ def test_df_diff_timedelta_df(periods):
eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.diff(periods=periods))


@sql_count_checker(query_count=1)
@pytest.mark.parametrize("periods", [-1, 0, 1])
@pytest.mark.parametrize("axis", [0, 1])
def test_df_diff_datetime_df(periods, axis):
native_df = native_pd.DataFrame(
np.arange(NUM_ROWS_TALL_DF * NUM_COLS_TALL_DF).reshape(
(NUM_ROWS_TALL_DF, NUM_COLS_TALL_DF)
),
columns=["A", "B", "C", "D"],
)
native_df = native_df.astype("datetime64[ns]")
snow_df = pd.DataFrame(native_df)
eval_snowpark_pandas_result(
snow_df, native_df, lambda df: df.diff(periods=periods, axis=axis)
)


@sql_count_checker(query_count=1)
@pytest.mark.parametrize("periods", [0, 1])
def test_df_diff_int_and_bool_df(periods):
4 changes: 3 additions & 1 deletion tests/integ/modin/series/test_diff.py
@@ -50,7 +50,9 @@ def test_series_diff_invalid_periods_negative():
snow_ser.diff("1").to_pandas()


@pytest.mark.parametrize("ser_type", [bool, int, object])
@pytest.mark.parametrize(
"ser_type", [bool, int, object, "timedelta64[ns]", "datetime64[ns]"]
)
@pytest.mark.parametrize(
"periods",
[
29 changes: 13 additions & 16 deletions tests/integ/modin/test_timedelta_ops.py
@@ -6,13 +6,15 @@
import modin.pandas as pd
import numpy as np
import pandas as native_pd
import pytest
from pandas import Timestamp

import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark.exceptions import SnowparkSQLException
from tests.integ.modin.sql_counter import sql_count_checker
from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result
from tests.integ.modin.utils import (
assert_series_equal,
create_test_dfs,
eval_snowpark_pandas_result,
)

TIME_DATA1 = {
"CREATED_AT": ["2018-8-26 15:09:02", "2018-8-25 11:10:07", "2018-8-27 12:05:00"],
@@ -70,8 +72,8 @@ def test_insert_datetime_difference():
)


@sql_count_checker(query_count=0)
def test_diff_timestamp_column_to_get_timedelta_negative():
@sql_count_checker(query_count=1)
def test_diff_timestamp_column_to_get_timedelta():
data = {
"Country": ["A", "B", "C", "D", "E"],
"Agreement Signing Date": [
@@ -82,14 +84,9 @@ def test_diff_timestamp_column_to_get_timedelta_negative():
pd.Timestamp("2017-08-09"),
],
}
snow_df = pd.DataFrame(data)
native_df = native_pd.DataFrame(data)
# TODO SNOW-1641729: remove Exception raised when TimeDelta is implemented
with pytest.raises(SnowparkSQLException):
eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: df.set_index("Country")
.diff()
.rename(columns={"Agreement Signing Date": "DiffDaysPrevAggrement"}),
)
eval_snowpark_pandas_result(
*create_test_dfs(data),
lambda df: df.set_index("Country")
.diff()
.rename(columns={"Agreement Signing Date": "DiffDaysPrevAggrement"}),
)
