Skip to content

Commit

Permalink
SNOW-1625379 Test coverage for timedelta under modin/integ/frame part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-azhan committed Aug 27, 2024
1 parent 0a9bbc7 commit 6fe524b
Show file tree
Hide file tree
Showing 29 changed files with 392 additions and 30 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
#### New Features

- Added limited support for the `Timedelta` type, including
- supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`.
- supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `fillna`, `diff`, `duplicated`, `empty`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `melt`, `nlargest`, `nsmallest`.
- converting non-timedelta to timedelta via `astype`.
- `NotImplementedError` will be raised for the remaining methods that do not support `Timedelta`.
- support for subtracting two timestamps to get a Timedelta.
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
)
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.indexing_utils import set_frame_2d_labels
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
SnowparkPandasType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_series_type
from snowflake.snowpark.modin.plugin._internal.utils import (
append_columns,
Expand Down Expand Up @@ -100,6 +103,13 @@ def scalar_isin_expression(
for literal_expr in values
]

    # Case 4: If the column's and values' data types differ and either type is a SnowparkPandasType
elif values_dtype != column_dtype and (
isinstance(values_dtype, SnowparkPandasType)
or isinstance(column_dtype, SnowparkPandasType)
):
return pandas_lit(False)

values = array_construct(*values)

# to_variant is a requirement for array_contains, else an error is produced.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ def get_snowpark_pandas_type_for_pandas_type(
return _type_to_snowpark_pandas_type[pandas_type]()
return None

def type_match(self, value: Any) -> bool:
    """Check whether ``value``'s Python type maps to this Snowpark pandas type."""
    return self == SnowparkPandasType.get_snowpark_pandas_type_for_pandas_type(
        type(value)
    )


class SnowparkPandasColumn(NamedTuple):
"""A Snowpark Column that has an optional SnowparkPandasType."""
Expand All @@ -111,7 +118,7 @@ class SnowparkPandasColumn(NamedTuple):
snowpark_pandas_type: Optional[SnowparkPandasType]


class TimedeltaType(SnowparkPandasType, LongType):
class TimedeltaType(SnowparkPandasType):
"""
Timedelta represents the difference between two times.
Expand All @@ -133,6 +140,12 @@ def __init__(self) -> None:
WarningMessage.single_warning(TIMEDELTA_WARNING_MESSAGE)
super().__init__()

def __eq__(self, other: Any) -> bool:
    """Two instances are equal iff they share a class and identical instance state."""
    return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

def __ne__(self, other: Any) -> bool:
    return not self.__eq__(other)

def __hash__(self) -> int:
    # Overriding __eq__ implicitly sets __hash__ to None, making instances
    # unhashable and breaking uses such as set(...) over collected column
    # types (see _simple_unpivot's set(unpivot_quoted_column_types)).
    # Equal instances always share a class, so hashing by type is consistent
    # with __eq__ above.
    return hash(type(self))

@staticmethod
def to_pandas(value: int) -> native_pd.Timedelta:
"""
Expand Down
21 changes: 17 additions & 4 deletions src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,12 +735,16 @@ def _simple_unpivot(
# create the initial set of columns to be retained as identifiers and those
# which will be unpivoted. Collect data type information.
unpivot_quoted_columns = []
unpivot_quoted_column_types = []

ordering_decode_conditions = []
id_col_names = []
id_col_quoted_identifiers = []
for (pandas_label, snowflake_quoted_identifier) in zip(
id_col_types = []
for (pandas_label, snowflake_quoted_identifier, sp_pandas_type) in zip(
frame.data_column_pandas_labels,
frame.data_column_snowflake_quoted_identifiers,
frame.cached_data_column_snowpark_pandas_types,
):
is_id_col = pandas_label in pandas_id_columns
is_var_col = pandas_label in pandas_value_columns
Expand All @@ -752,9 +756,11 @@ def _simple_unpivot(
col(var_quoted) == pandas_lit(pandas_label)
)
unpivot_quoted_columns.append(snowflake_quoted_identifier)
unpivot_quoted_column_types.append(sp_pandas_type)
if is_id_col:
id_col_names.append(pandas_label)
id_col_quoted_identifiers.append(snowflake_quoted_identifier)
id_col_types.append(sp_pandas_type)

# create the case expressions used for the final result set ordering based
    # on the column position. This clause will be applied after the unpivot
Expand Down Expand Up @@ -787,7 +793,7 @@ def _simple_unpivot(
pandas_labels=[unquoted_col_name],
)[0]
)
# coalese the values to unpivot and preserve null values This code
    # coalesce the values to unpivot and preserve null values. This code
# can be removed when UNPIVOT_INCLUDE_NULLS is enabled
unpivot_columns_normalized_types.append(
coalesce(to_variant(c), to_variant(pandas_lit(null_replace_value))).alias(
Expand Down Expand Up @@ -870,6 +876,13 @@ def _simple_unpivot(
var_quoted,
corrected_value_column_name,
]
corrected_value_column_type = None
if len(set(unpivot_quoted_column_types)) == 1:
corrected_value_column_type = unpivot_quoted_column_types[0]
final_snowflake_quoted_col_types = id_col_types + [
None,
corrected_value_column_type,
]

# Create the new frame and compiler
return InternalFrame.create(
Expand All @@ -881,8 +894,8 @@ def _simple_unpivot(
index_column_snowflake_quoted_identifiers=[
ordered_dataframe.row_position_snowflake_quoted_identifier
],
data_column_types=None,
index_column_types=None,
data_column_types=final_snowflake_quoted_col_types,
index_column_types=[None],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5753,8 +5753,6 @@ def insert(
Returns:
A new SnowflakeQueryCompiler instance with new column.
"""
self._raise_not_implemented_error_for_timedelta()

if not isinstance(value, SnowflakeQueryCompiler):
# Scalar value
new_internal_frame = self._modin_frame.append_column(
Expand Down Expand Up @@ -5844,7 +5842,9 @@ def move_last_element(arr: list, index: int) -> None:
data_column_snowflake_quoted_identifiers = (
new_internal_frame.data_column_snowflake_quoted_identifiers
)
data_column_types = new_internal_frame.cached_data_column_snowpark_pandas_types
move_last_element(data_column_snowflake_quoted_identifiers, loc)
move_last_element(data_column_types, loc)

new_internal_frame = InternalFrame.create(
ordered_dataframe=new_internal_frame.ordered_dataframe,
Expand All @@ -5853,8 +5853,8 @@ def move_last_element(arr: list, index: int) -> None:
data_column_pandas_index_names=new_internal_frame.data_column_pandas_index_names,
index_column_pandas_labels=new_internal_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=new_internal_frame.index_column_snowflake_quoted_identifiers,
data_column_types=None,
index_column_types=None,
data_column_types=data_column_types,
index_column_types=new_internal_frame.cached_index_column_snowpark_pandas_types,
)
return SnowflakeQueryCompiler(new_internal_frame)

Expand Down Expand Up @@ -6628,8 +6628,6 @@ def melt(
Notes:
melt does not yet handle multiindex or ignore index
"""
self._raise_not_implemented_error_for_timedelta()

if col_level is not None:
raise NotImplementedError(
"Snowpark Pandas doesn't support 'col_level' argument in melt API"
Expand Down Expand Up @@ -6732,8 +6730,6 @@ def merge(
Returns:
SnowflakeQueryCompiler instance with merged result.
"""
self._raise_not_implemented_error_for_timedelta()

if validate:
ErrorMessage.not_implemented(
"Snowpark pandas merge API doesn't yet support 'validate' parameter"
Expand Down Expand Up @@ -9790,6 +9786,10 @@ def _fillna_with_masking(

# case 2: fillna with a method
if method is not None:
# no Snowpark pandas type change in this case
data_column_snowpark_pandas_types = (
self._modin_frame.cached_data_column_snowpark_pandas_types
)
method = FillNAMethod.get_enum_for_string_method(method)
method_is_ffill = method is FillNAMethod.FFILL_METHOD
if axis == 0:
Expand Down Expand Up @@ -9896,6 +9896,7 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
include_index=False,
)
fillna_column_map = {}
data_column_snowpark_pandas_types = []
if columns_mask is not None:
columns_to_ignore = itertools.compress(
self._modin_frame.data_column_pandas_labels,
Expand All @@ -9915,10 +9916,18 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
col(id),
coalesce(id, pandas_lit(val)),
)
col_type = self._modin_frame.get_snowflake_type(id)
col_pandas_type = (
col_type
if isinstance(col_type, SnowparkPandasType)
and col_type.type_match(val)
else None
)
data_column_snowpark_pandas_types.append(col_pandas_type)

return SnowflakeQueryCompiler(
self._modin_frame.update_snowflake_quoted_identifiers_with_expressions(
fillna_column_map
fillna_column_map, data_column_snowpark_pandas_types
).frame
)

Expand Down Expand Up @@ -10192,7 +10201,8 @@ def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler":
}
return SnowflakeQueryCompiler(
self._modin_frame.update_snowflake_quoted_identifiers_with_expressions(
diff_label_to_value_map
diff_label_to_value_map,
self._modin_frame.cached_data_column_snowpark_pandas_types,
).frame
)

Expand Down
12 changes: 11 additions & 1 deletion tests/integ/modin/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import pandas as native_pd

RAW_NA_DF_DATA_TEST_CASES = [
({"A": [1, 2, 3], "B": [4, 5, 6]}, "numeric-no"),
Expand All @@ -16,9 +17,18 @@
({"A": [True, 1, "X"], "B": ["Y", 3.14, False]}, "mixed"),
({"A": [True, None, "X"], "B": [None, 3.14, None]}, "mixed-mixed-1"),
({"A": [None, 1, None], "B": ["Y", None, False]}, "mixed-mixed-2"),
(
{
"A": [None, native_pd.Timedelta(2), None],
"B": [native_pd.Timedelta(4), None, native_pd.Timedelta(6)],
},
"timedelta-mixed-1",
),
]

RAW_NA_DF_SERIES_TEST_CASES = [
    (list(df_data.values()), test_case)
    # Exclude only the final case ("timedelta-mixed-1"): Timedelta values are
    # not JSON serializable. The previous slice [:1] kept just the first case,
    # silently dropping every other (serializable) test case as well.
    for (df_data, test_case) in RAW_NA_DF_DATA_TEST_CASES[:-1]
]
21 changes: 21 additions & 0 deletions tests/integ/modin/frame/test_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,24 @@ def test_overwrite_columns_via_assign():
eval_snowpark_pandas_result(
snow_df, native_df, lambda df: df.assign(a=df["b"], last_col=[10, 11, 12])
)


@sql_count_checker(query_count=2, join_count=1)
def test_assign_basic_timedelta_series():
    """Assigning a timedelta Series as a new column should match native pandas."""
    snow_df, native_df = create_test_dfs(
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        columns=native_pd.Index(list("abc"), name="columns"),
        index=native_pd.Index([0, 1, 2], name="index"),
    )
    native_df.columns.names = ["columns"]
    native_df.index.names = ["index"]

    timedeltas = native_pd.timedelta_range("1 day", periods=3)

    def do_assign(df):
        # Wrap the timedeltas in the Series flavor matching the frame under test.
        series_cls = pd.Series if isinstance(df, pd.DataFrame) else native_pd.Series
        return df.assign(new_col=series_cls(timedeltas))

    eval_snowpark_pandas_result(snow_df, native_df, do_assign)
23 changes: 22 additions & 1 deletion tests/integ/modin/frame/test_bfill_ffill.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@pytest.mark.parametrize("func", ["backfill", "bfill", "ffill", "pad"])
@sql_count_checker(query_count=1)
def test_df_func(func):
def test_df_fill(func):
native_df = native_pd.DataFrame(
[
[np.nan, 2, np.nan, 0],
Expand All @@ -31,3 +31,24 @@ def test_df_func(func):
native_df,
lambda df: getattr(df, func)(),
)


@pytest.mark.parametrize("func", ["backfill", "bfill", "ffill", "pad"])
@sql_count_checker(query_count=1)
def test_df_timedelta_fill(func):
    """Fill methods should propagate values through timedelta columns like pandas."""
    rows = [
        [np.nan, 2, np.nan, 0],
        [3, 4, np.nan, 1],
        [np.nan, np.nan, np.nan, np.nan],
        [np.nan, 3, np.nan, 4],
        [3, np.nan, 4, np.nan],
    ]
    native_df = native_pd.DataFrame(rows, columns=list("ABCD")).astype(
        "timedelta64[ns]"
    )
    eval_snowpark_pandas_result(
        pd.DataFrame(native_df),
        native_df,
        lambda df: getattr(df, func)(),
    )
15 changes: 5 additions & 10 deletions tests/integ/modin/frame/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,10 @@
def base_df() -> native_pd.DataFrame:
return native_pd.DataFrame(
[
[None, None, 3.1, pd.Timestamp("2024-01-01"), [130]],
[
"a",
1,
4.2,
pd.Timestamp("2024-02-01"),
[131],
],
["b", 2, 5.3, pd.Timestamp("2024-03-01"), [132]],
[None, 3, 6.4, pd.Timestamp("2024-04-01"), [133]],
[None, None, 3.1, pd.Timestamp("2024-01-01"), [130], pd.Timedelta(1)],
["a", 1, 4.2, pd.Timestamp("2024-02-01"), [131], pd.Timedelta(11)],
["b", 2, 5.3, pd.Timestamp("2024-03-01"), [132], pd.Timedelta(21)],
[None, 3, 6.4, pd.Timestamp("2024-04-01"), [133], pd.Timedelta(13)],
],
index=pd.MultiIndex.from_tuples(
[
Expand All @@ -54,6 +48,7 @@ def base_df() -> native_pd.DataFrame:
("group_2", "float_col"),
("group_2", "timestamp_col"),
("group_2", "list_col"),
("group_2", "timedelta_col"),
],
names=["column_level1", "column_level2"],
),
Expand Down
14 changes: 14 additions & 0 deletions tests/integ/modin/frame/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ def test_df_diff_bool_df(periods):
eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.diff(periods=periods))


@sql_count_checker(query_count=1)
@pytest.mark.parametrize("periods", [0, 1])
def test_df_diff_timedelta_df(periods):
    """diff() over a frame mixing timedelta and integer columns matches pandas."""
    values = np.arange(NUM_ROWS_TALL_DF * NUM_COLS_TALL_DF)
    native_df = native_pd.DataFrame(
        values.reshape((NUM_ROWS_TALL_DF, NUM_COLS_TALL_DF)),
        columns=["A", "B", "C", "D"],
    ).astype({"A": "timedelta64[ns]", "C": "timedelta64[ns]"})
    snow_df = pd.DataFrame(native_df)
    eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.diff(periods=periods))


@sql_count_checker(query_count=1)
@pytest.mark.parametrize("periods", [0, 1])
def test_df_diff_int_and_bool_df(periods):
Expand Down
12 changes: 12 additions & 0 deletions tests/integ/modin/frame/test_drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ def test_drop_list_like(native_df, labels):
eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.drop(labels, axis=1))


@pytest.mark.parametrize(
    "labels", [Index(["red", "green"]), np.array(["red", "green"])]
)
@sql_count_checker(query_count=1)
def test_drop_timedelta(native_df, labels):
    """Dropping columns by label works when one column holds timedelta values."""
    with_td = native_df.astype({"red": "timedelta64[ns]"})
    eval_snowpark_pandas_result(
        pd.DataFrame(with_td), with_td, lambda df: df.drop(labels, axis=1)
    )


@pytest.mark.parametrize(
"labels, axis, expected_query_count",
[
Expand Down
1 change: 1 addition & 0 deletions tests/integ/modin/frame/test_dropna.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_dropna_df():
"name": ["Alfred", "Batman", "Catwoman"],
"toy": [np.nan, "Batmobile", "Bullwhip"],
"born": [pd.NaT, pd.Timestamp("1940-04-25"), pd.NaT],
"dt": [pd.NaT, pd.Timedelta(1), pd.NaT],
}
)

Expand Down
Loading

0 comments on commit 6fe524b

Please sign in to comment.