From a5d06d9dfc0e445de0a458272ee060e2b4d3d592 Mon Sep 17 00:00:00 2001 From: sfc-gh-mvashishtha Date: Fri, 30 Aug 2024 16:28:34 -0700 Subject: [PATCH] SNOW-1641729: Support diff() for timestamp columns. Signed-off-by: sfc-gh-mvashishtha --- CHANGELOG.md | 1 + .../compiler/snowflake_query_compiler.py | 57 +++++++++++++------ tests/integ/modin/frame/test_diff.py | 17 ++++++ tests/integ/modin/series/test_diff.py | 4 +- tests/integ/modin/test_timedelta_ops.py | 29 +++++----- 5 files changed, 75 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0767d5d3a0a..59bb4947da1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,7 @@ - support for lazy `TimedeltaIndex`. - support for `pd.to_timedelta`. - support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`. + - support for `diff()` with timestamp columns on `axis=0` and `axis=1`. - Added support for index's arithmetic and comparison operators. - Added support for `Series.dt.round`. - Added documentation pages for `DatetimeIndex`. diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 50ce5e71310..273cb1d0388 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -10189,7 +10189,7 @@ def _make_discrete_difference_expression( periods: int, column_position: int, axis: int, - ) -> SnowparkColumn: + ) -> SnowparkPandasColumn: """ Helper function to generate Columns for discrete difference. @@ -10207,9 +10207,10 @@ def _make_discrete_difference_expression( Returns ------- - SnowparkColumn - An expression to generate the discrete difference along the specified axis, with the - specified period, for the column specified by `column_position`. 
+ SnowparkPandasColumn + A column representing the discrete difference along the specified + axis, with the specified period, for the column specified by + `column_position`. """ # If periods is 0, we are doing a subtraction with self (or XOR in case of bool # dtype). In this case, even if axis is 0, we prefer to use the col-wise code, @@ -10239,15 +10240,25 @@ def _make_discrete_difference_expression( self._modin_frame.ordering_column_snowflake_quoted_identifiers ) ) - return (col1 | col2) & (not_(col1 & col2)) + return SnowparkPandasColumn( + snowpark_column=(col1 | col2) & (not_(col1 & col2)), + snowpark_pandas_type=None, + ) else: - return col(snowflake_quoted_identifier) - func_for_other( - snowflake_quoted_identifier, offset=abs(periods) - ).over( - Window.order_by( - self._modin_frame.ordering_column_snowflake_quoted_identifiers - ) + return compute_binary_op_between_snowpark_columns( + "sub", + col(snowflake_quoted_identifier), + lambda: column_datatype, + func_for_other( + snowflake_quoted_identifier, offset=abs(periods) + ).over( + Window.order_by( + self._modin_frame.ordering_column_snowflake_quoted_identifiers + ) + ), + lambda: column_datatype, ) + else: # periods is the number of columns to *go back*. periods *= -1 @@ -10258,7 +10269,9 @@ def _make_discrete_difference_expression( if other_column_position < 0 or other_column_position >= len( self._modin_frame.data_column_snowflake_quoted_identifiers ): - return pandas_lit(np.nan) + return SnowparkPandasColumn( + snowpark_column=pandas_lit(np.nan), snowpark_pandas_type=None + ) # In this case, we are at a column that does have a match, so we must do dtype checking # and then generate the expression. 
else: @@ -10285,13 +10298,21 @@ def _make_discrete_difference_expression( if isinstance(col1_dtype, BooleanType) and isinstance( col2_dtype, BooleanType ): - return (col1 | col2) & (not_(col1 & col2)) + return SnowparkPandasColumn( + (col1 | col2) & (not_(col1 & col2)), snowpark_pandas_type=None + ) else: if isinstance(col1_dtype, BooleanType): col1 = cast(col1, IntegerType()) if isinstance(col2_dtype, BooleanType): col2 = cast(col2, IntegerType()) - return col1 - col2 + return compute_binary_op_between_snowpark_columns( + "sub", + col1, + lambda: col1_dtype, + col2, + lambda: col2_dtype, + ) def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler": """ @@ -10312,8 +10333,12 @@ def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler": } return SnowflakeQueryCompiler( self._modin_frame.update_snowflake_quoted_identifiers_with_expressions( - diff_label_to_value_map, - self._modin_frame.cached_data_column_snowpark_pandas_types, + quoted_identifier_to_column_map={ + k: v.snowpark_column for k, v in diff_label_to_value_map.items() + }, + data_column_snowpark_pandas_types=[ + a.snowpark_pandas_type for a in diff_label_to_value_map.values() + ], ).frame ) diff --git a/tests/integ/modin/frame/test_diff.py b/tests/integ/modin/frame/test_diff.py index 185b2eab89e..493b108be89 100644 --- a/tests/integ/modin/frame/test_diff.py +++ b/tests/integ/modin/frame/test_diff.py @@ -154,6 +154,23 @@ def test_df_diff_timedelta_df(periods): eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.diff(periods=periods)) +@sql_count_checker(query_count=1) +@pytest.mark.parametrize("periods", [-1, 0, 1]) +@pytest.mark.parametrize("axis", [0, 1]) +def test_df_diff_datetime_df(periods, axis): + native_df = native_pd.DataFrame( + np.arange(NUM_ROWS_TALL_DF * NUM_COLS_TALL_DF).reshape( + (NUM_ROWS_TALL_DF, NUM_COLS_TALL_DF) + ), + columns=["A", "B", "C", "D"], + ) + native_df = native_df.astype("datetime64[ns]") + snow_df = pd.DataFrame(native_df) + 
eval_snowpark_pandas_result( + snow_df, native_df, lambda df: df.diff(periods=periods, axis=axis) + ) + + @sql_count_checker(query_count=1) @pytest.mark.parametrize("periods", [0, 1]) def test_df_diff_int_and_bool_df(periods): diff --git a/tests/integ/modin/series/test_diff.py b/tests/integ/modin/series/test_diff.py index 7878195ba6a..1d6412d11c2 100644 --- a/tests/integ/modin/series/test_diff.py +++ b/tests/integ/modin/series/test_diff.py @@ -50,7 +50,9 @@ def test_series_diff_invalid_periods_negative(): snow_ser.diff("1").to_pandas() -@pytest.mark.parametrize("ser_type", [bool, int, object]) +@pytest.mark.parametrize( + "ser_type", [bool, int, object, "timedelta64[ns]", "datetime64[ns]"] +) @pytest.mark.parametrize( "periods", [ diff --git a/tests/integ/modin/test_timedelta_ops.py b/tests/integ/modin/test_timedelta_ops.py index c60b91b3273..4bf4e78b12c 100644 --- a/tests/integ/modin/test_timedelta_ops.py +++ b/tests/integ/modin/test_timedelta_ops.py @@ -6,13 +6,15 @@ import modin.pandas as pd import numpy as np import pandas as native_pd -import pytest from pandas import Timestamp import snowflake.snowpark.modin.plugin # noqa: F401 -from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + assert_series_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) TIME_DATA1 = { "CREATED_AT": ["2018-8-26 15:09:02", "2018-8-25 11:10:07", "2018-8-27 12:05:00"], @@ -70,8 +72,8 @@ def test_insert_datetime_difference(): ) -@sql_count_checker(query_count=0) -def test_diff_timestamp_column_to_get_timedelta_negative(): +@sql_count_checker(query_count=1) +def test_diff_timestamp_column_to_get_timedelta(): data = { "Country": ["A", "B", "C", "D", "E"], "Agreement Signing Date": [ @@ -82,14 +84,9 @@ def test_diff_timestamp_column_to_get_timedelta_negative(): pd.Timestamp("2017-08-09"), 
], } - snow_df = pd.DataFrame(data) - native_df = native_pd.DataFrame(data) - # TODO SNOW-1641729: remove Exception raised when TimeDelta is implemented - with pytest.raises(SnowparkSQLException): - eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.set_index("Country") - .diff() - .rename(columns={"Agreement Signing Date": "DiffDaysPrevAggrement"}), - ) + eval_snowpark_pandas_result( + *create_test_dfs(data), + lambda df: df.set_index("Country") + .diff() + .rename(columns={"Agreement Signing Date": "DiffDaysPrevAggrement"}), + )