SNOW-1625379 Test coverage for timedelta under modin/integ/frame part 1 #2171

Merged 6 commits on Aug 28, 2024
Changes from all commits
CHANGELOG.md (1 addition, 1 deletion)
@@ -51,7 +51,7 @@
#### New Features

- Added limited support for the `Timedelta` type, including the following features. Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases.
-  - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`.
+  - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `assign`, `bfill`, `ffill`, `fillna`, `compare`, `diff`, `drop`, `dropna`, `duplicated`, `empty`, `equals`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `mask`, `melt`, `merge`, `nlargest`, `nsmallest`.
- converting non-timedelta to timedelta via `astype`.
- `NotImplementedError` will be raised for the rest of methods that do not support `Timedelta`.
- support for subtracting two timestamps to get a Timedelta.
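Note: as a quick illustration of what "tracking the Timedelta type" means, this is the native pandas behavior that Snowpark pandas mirrors (shown with plain pandas here; running the same code against Snowflake would additionally need a configured session):

```python
import pandas as pd

df = pd.DataFrame({"td": pd.timedelta_range("1 day", periods=3)})
assert df.copy()["td"].dtype == "timedelta64[ns]"    # type survives copy
assert df.shift(1)["td"].dtype == "timedelta64[ns]"  # and shift (holes become NaT)

# astype converts non-timedelta to timedelta; integers are read as nanoseconds:
s = pd.Series([1_000_000_000, 2_000_000_000]).astype("timedelta64[ns]")
assert s[0] == pd.Timedelta(seconds=1)
```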
@@ -512,10 +512,8 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool:
Returns:
True if given types are equal, False otherwise.
"""
-    if isinstance(type1, TimedeltaType) and not isinstance(type2, TimedeltaType):
-        return False
-    if isinstance(type2, TimedeltaType) and not isinstance(type1, TimedeltaType):
-        return False
+    if isinstance(type1, TimedeltaType) or isinstance(type2, TimedeltaType):
+        return type1 == type2
if isinstance(type1, _IntegralType) and isinstance(type2, _IntegralType):
return True
if isinstance(type1, _FractionalType) and isinstance(type2, _FractionalType):
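Note: a minimal sketch of why the Timedelta short-circuit has to come before the integral widening rule. The classes below are hypothetical stand-ins, not the real Snowpark `DataType` hierarchy; the sketch assumes (as this module does) that the Timedelta type is integer-backed:

```python
class _IntegralType: ...
class LongType(_IntegralType): ...

class TimedeltaType(LongType):  # assumption: Timedelta columns are stored as integers
    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

def are_equal_types(t1, t2):
    # New rule: any Timedelta on either side forces strict equality...
    if isinstance(t1, TimedeltaType) or isinstance(t2, TimedeltaType):
        return t1 == t2
    # ...otherwise integral types remain interchangeable, as before.
    if isinstance(t1, _IntegralType) and isinstance(t2, _IntegralType):
        return True
    return t1 == t2

assert not are_equal_types(TimedeltaType(), LongType())  # no silent int/timedelta mixing
assert are_equal_types(LongType(), LongType())
assert are_equal_types(TimedeltaType(), TimedeltaType())
```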
src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py (10 additions)

@@ -14,6 +14,9 @@
)
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.indexing_utils import set_frame_2d_labels
+from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
+    SnowparkPandasType,
+)
from snowflake.snowpark.modin.plugin._internal.type_utils import infer_series_type
from snowflake.snowpark.modin.plugin._internal.utils import (
append_columns,
@@ -100,6 +103,13 @@ def scalar_isin_expression(
for literal_expr in values
]

+    # Case 4: If the column's and values' data types differ and either type is
+    # a SnowparkPandasType, nothing can match.
+    elif values_dtype != column_dtype and (
+        isinstance(values_dtype, SnowparkPandasType)
+        or isinstance(column_dtype, SnowparkPandasType)
+    ):
+        return pandas_lit(False)
+
values = array_construct(*values)

# to_variant is a requirement for array_contains, else an error is produced.
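Note: returning a literal `False` in Case 4 matches native pandas semantics, where values of a different type never match timedeltas:

```python
import pandas as pd

s = pd.Series(pd.to_timedelta([1, 2, 3]))   # nanosecond timedeltas
assert not s.isin([1, 2]).any()             # plain ints never match timedeltas
assert s.isin([pd.Timedelta(1)]).any()      # a real Timedelta does match
```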
src/snowflake/snowpark/modin/plugin/_internal/join_utils.py (24 additions)

@@ -172,6 +172,30 @@ def join(
JoinTypeLit
), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}"

+    def assert_snowpark_pandas_types_match() -> None:
+        """Raise a ValueError if the Snowpark pandas types of the join keys do not match."""
+        left_types = [
+            left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
+            for id in left_on
+        ]
+        right_types = [
+            right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
+            for id in right_on
+        ]
+        for i, (lt, rt) in enumerate(zip(left_types, right_types)):
+            if lt != rt:
+                left_on_id = left_on[i]
+                idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
+                key = left.data_column_pandas_labels[idx]
+                lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
+                rt = rt if rt is not None else right.get_snowflake_type(right_on[i])
+                raise ValueError(
+                    f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
+                    f"If you wish to proceed you should use pd.concat"
+                )
+
+    assert_snowpark_pandas_types_match()
+
# Re-project the active columns to make sure all active columns of the internal frame participate
# in the join operation, and unnecessary columns are dropped from the projected columns.
left = left.select_active_columns()
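Note: the error message mirrors what native pandas raises for mismatched merge keys (the exact wording varies slightly by pandas version):

```python
import pandas as pd

left = pd.DataFrame({"key": [1, 2], "x": ["a", "b"]})
right = pd.DataFrame({"key": pd.to_timedelta([1, 2]), "y": [10, 20]})
try:
    left.merge(right, on="key")
except ValueError as e:
    print(e)
    # e.g. "You are trying to merge on int64 and timedelta64[ns] columns
    # for key 'key'. If you wish to proceed you should use pd.concat"
```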
@@ -101,6 +101,13 @@ def get_snowpark_pandas_type_for_pandas_type(
return _type_to_snowpark_pandas_type[pandas_type]()
return None

+    def type_match(self, value: Any) -> bool:
+        """Return True if the value's type matches self."""
+        val_type = SnowparkPandasType.get_snowpark_pandas_type_for_pandas_type(
+            type(value)
+        )
+        return self == val_type
+

class SnowparkPandasColumn(NamedTuple):
"""A Snowpark Column that has an optional SnowparkPandasType."""
@@ -133,6 +140,12 @@ def __init__(self) -> None:
WarningMessage.single_warning(TIMEDELTA_WARNING_MESSAGE)
super().__init__()

+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+
+    def __ne__(self, other: Any) -> bool:
+        return not self.__eq__(other)
+
@staticmethod
def to_pandas(value: int) -> native_pd.Timedelta:
"""
src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py (17 additions, 4 deletions)

@@ -735,12 +735,16 @@ def _simple_unpivot(
# create the initial set of columns to be retained as identifiers and those
# which will be unpivoted. Collect data type information.
unpivot_quoted_columns = []
+    unpivot_quoted_column_types = []

ordering_decode_conditions = []
id_col_names = []
id_col_quoted_identifiers = []
-    for (pandas_label, snowflake_quoted_identifier) in zip(
+    id_col_types = []
+    for (pandas_label, snowflake_quoted_identifier, sp_pandas_type) in zip(
frame.data_column_pandas_labels,
frame.data_column_snowflake_quoted_identifiers,
+        frame.cached_data_column_snowpark_pandas_types,
):
is_id_col = pandas_label in pandas_id_columns
is_var_col = pandas_label in pandas_value_columns
@@ -752,9 +756,11 @@
col(var_quoted) == pandas_lit(pandas_label)
)
unpivot_quoted_columns.append(snowflake_quoted_identifier)
+            unpivot_quoted_column_types.append(sp_pandas_type)
if is_id_col:
id_col_names.append(pandas_label)
id_col_quoted_identifiers.append(snowflake_quoted_identifier)
+            id_col_types.append(sp_pandas_type)

# create the case expressions used for the final result set ordering based
# on the column position. This clause will be appled after the unpivot
@@ -787,7 +793,7 @@
pandas_labels=[unquoted_col_name],
)[0]
)
-        # coalese the values to unpivot and preserve null values This code
+        # coalesce the values to unpivot and preserve null values This code
# can be removed when UNPIVOT_INCLUDE_NULLS is enabled
unpivot_columns_normalized_types.append(
coalesce(to_variant(c), to_variant(pandas_lit(null_replace_value))).alias(
@@ -870,6 +876,13 @@ def _simple_unpivot(
var_quoted,
corrected_value_column_name,
]
+    corrected_value_column_type = None
+    if len(set(unpivot_quoted_column_types)) == 1:
+        corrected_value_column_type = unpivot_quoted_column_types[0]
+    final_snowflake_quoted_col_types = id_col_types + [
+        None,
+        corrected_value_column_type,
+    ]

# Create the new frame and compiler
return InternalFrame.create(
Expand All @@ -881,8 +894,8 @@ def _simple_unpivot(
index_column_snowflake_quoted_identifiers=[
ordered_dataframe.row_position_snowflake_quoted_identifier
],
-        data_column_types=None,
-        index_column_types=None,
+        data_column_types=final_snowflake_quoted_col_types,
+        index_column_types=[None],
)
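Note: the "all unpivoted columns share one type" rule matches native pandas `melt`, where the value column keeps a timedelta dtype only when every melted column has it, and the variable column is plain labels (hence the `None` type slot):

```python
import pandas as pd

df = pd.DataFrame({
    "a": pd.to_timedelta([1, 2]),
    "b": pd.to_timedelta([3, 4]),
})
m = df.melt()
assert m["value"].dtype == "timedelta64[ns]"  # single shared type is preserved
assert m["variable"].dtype == object          # column labels carry no Timedelta type
```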


@@ -1499,7 +1499,7 @@ def _shift_values_axis_0(
row_position_quoted_identifier = frame.row_position_snowflake_quoted_identifier

fill_value_dtype = infer_object_type(fill_value)
-    fill_value = pandas_lit(fill_value) if fill_value is not None else None
+    fill_value = None if pd.isna(fill_value) else pandas_lit(fill_value)

def shift_expression_and_type(
quoted_identifier: str, dtype: DataType
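Note: a small check of why `pd.isna` is the right guard here. The old test (`fill_value is not None`) only special-cased `None`, so `np.nan` and `pd.NaT` slipped through and got wrapped in a literal:

```python
import numpy as np
import pandas as pd

# Every missing-value sentinel should map to None, not to a literal:
for missing in (None, np.nan, pd.NaT):
    assert pd.isna(missing)

# ...while real fill values still get wrapped:
assert not pd.isna(pd.Timedelta(1))
assert not pd.isna(0)
```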
@@ -5757,8 +5757,6 @@ def insert(
Returns:
A new SnowflakeQueryCompiler instance with new column.
"""
-        self._raise_not_implemented_error_for_timedelta()
-
if not isinstance(value, SnowflakeQueryCompiler):
# Scalar value
new_internal_frame = self._modin_frame.append_column(
@@ -5848,7 +5846,9 @@ def move_last_element(arr: list, index: int) -> None:
data_column_snowflake_quoted_identifiers = (
new_internal_frame.data_column_snowflake_quoted_identifiers
)
+        data_column_types = new_internal_frame.cached_data_column_snowpark_pandas_types
move_last_element(data_column_snowflake_quoted_identifiers, loc)
+        move_last_element(data_column_types, loc)

new_internal_frame = InternalFrame.create(
ordered_dataframe=new_internal_frame.ordered_dataframe,
@@ -5857,8 +5857,8 @@
data_column_pandas_index_names=new_internal_frame.data_column_pandas_index_names,
index_column_pandas_labels=new_internal_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=new_internal_frame.index_column_snowflake_quoted_identifiers,
-            data_column_types=None,
-            index_column_types=None,
+            data_column_types=data_column_types,
+            index_column_types=new_internal_frame.cached_index_column_snowpark_pandas_types,
)
return SnowflakeQueryCompiler(new_internal_frame)

@@ -6645,8 +6645,6 @@ def melt(
Notes:
melt does not yet handle multiindex or ignore index
"""
-        self._raise_not_implemented_error_for_timedelta()
-
if col_level is not None:
raise NotImplementedError(
"Snowpark Pandas doesn't support 'col_level' argument in melt API"
@@ -6749,8 +6747,6 @@ def merge(
Returns:
SnowflakeQueryCompiler instance with merged result.
"""
-        self._raise_not_implemented_error_for_timedelta()
-
if validate:
ErrorMessage.not_implemented(
"Snowpark pandas merge API doesn't yet support 'validate' parameter"
@@ -9815,6 +9811,10 @@ def _fillna_with_masking(

# case 2: fillna with a method
if method is not None:
+            # no Snowpark pandas type change in this case
+            data_column_snowpark_pandas_types = (
+                self._modin_frame.cached_data_column_snowpark_pandas_types
+            )
method = FillNAMethod.get_enum_for_string_method(method)
method_is_ffill = method is FillNAMethod.FFILL_METHOD
if axis == 0:
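Note: carrying the cached types through unchanged is safe because method-based fills only move existing values around; native pandas behaves the same way:

```python
import pandas as pd

s = pd.Series([pd.Timedelta(1), None, pd.Timedelta(3)])
assert s.dtype == "timedelta64[ns]"          # None is inferred as NaT
assert s.ffill().dtype == "timedelta64[ns]"  # ffill cannot change the type
```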
@@ -9921,6 +9921,7 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
include_index=False,
)
fillna_column_map = {}
+        data_column_snowpark_pandas_types = []
if columns_mask is not None:
columns_to_ignore = itertools.compress(
self._modin_frame.data_column_pandas_labels,
@@ -9940,10 +9941,18 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn:
col(id),
coalesce(id, pandas_lit(val)),
)
+                col_type = self._modin_frame.get_snowflake_type(id)
+                col_pandas_type = (
+                    col_type
+                    if isinstance(col_type, SnowparkPandasType)
+                    and col_type.type_match(val)
+                    else None
+                )
+                data_column_snowpark_pandas_types.append(col_pandas_type)

return SnowflakeQueryCompiler(
self._modin_frame.update_snowflake_quoted_identifiers_with_expressions(
-                fillna_column_map
+                fillna_column_map, data_column_snowpark_pandas_types
).frame
)

@@ -10217,7 +10226,8 @@ def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler":
}
return SnowflakeQueryCompiler(
self._modin_frame.update_snowflake_quoted_identifiers_with_expressions(
-                diff_label_to_value_map
+                diff_label_to_value_map,
+                self._modin_frame.cached_data_column_snowpark_pandas_types,
).frame
)
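Note: passing the cached types through is consistent with native pandas, where `diff` on a timedelta column yields timedeltas again:

```python
import pandas as pd

s = pd.Series(pd.to_timedelta([1, 3, 6], unit="D"))
assert s.diff().dtype == "timedelta64[ns]"  # differences of timedeltas are timedeltas
```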

tests/integ/modin/data.py (11 additions, 1 deletion)

@@ -1,6 +1,7 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
+import pandas as native_pd

RAW_NA_DF_DATA_TEST_CASES = [
({"A": [1, 2, 3], "B": [4, 5, 6]}, "numeric-no"),
@@ -16,9 +17,18 @@
({"A": [True, 1, "X"], "B": ["Y", 3.14, False]}, "mixed"),
({"A": [True, None, "X"], "B": [None, 3.14, None]}, "mixed-mixed-1"),
({"A": [None, 1, None], "B": ["Y", None, False]}, "mixed-mixed-2"),
+    (
+        {
+            "A": [None, native_pd.Timedelta(2), None],
+            "B": [native_pd.Timedelta(4), None, native_pd.Timedelta(6)],
+        },
+        "timedelta-mixed-1",
+    ),
]

RAW_NA_DF_SERIES_TEST_CASES = [
(list(df_data.values()), test_case)
-    for (df_data, test_case) in RAW_NA_DF_DATA_TEST_CASES
+    for (df_data, test_case) in RAW_NA_DF_DATA_TEST_CASES[
+        :1
+    ]  # "timedelta-mixed-1" is not json serializable
]
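Note: the inline comment refers to the fact that `Timedelta` objects cannot pass through `json.dumps`, which the series test-case path relies on:

```python
import json

import pandas as pd

try:
    json.dumps({"A": [None, pd.Timedelta(2), None]})
except TypeError as e:
    print(e)  # Object of type Timedelta is not JSON serializable
```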
tests/integ/modin/frame/test_assign.py (21 additions)

@@ -238,3 +238,24 @@ def test_overwrite_columns_via_assign():
eval_snowpark_pandas_result(
snow_df, native_df, lambda df: df.assign(a=df["b"], last_col=[10, 11, 12])
)


+@sql_count_checker(query_count=2, join_count=1)
+def test_assign_basic_timedelta_series():
+    snow_df, native_df = create_test_dfs(
+        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+        columns=native_pd.Index(list("abc"), name="columns"),
+        index=native_pd.Index([0, 1, 2], name="index"),
+    )
+    native_df.columns.names = ["columns"]
+    native_df.index.names = ["index"]
+
+    native_td = native_pd.timedelta_range("1 day", periods=3)
+
+    def assign_func(df):
+        if isinstance(df, pd.DataFrame):
+            return df.assign(new_col=pd.Series(native_td))
+        else:
+            return df.assign(new_col=native_pd.Series(native_td))
+
+    eval_snowpark_pandas_result(snow_df, native_df, assign_func)
tests/integ/modin/frame/test_bfill_ffill.py (22 additions, 1 deletion)

@@ -14,7 +14,7 @@

@pytest.mark.parametrize("func", ["backfill", "bfill", "ffill", "pad"])
@sql_count_checker(query_count=1)
-def test_df_func(func):
+def test_df_fill(func):
native_df = native_pd.DataFrame(
[
[np.nan, 2, np.nan, 0],
@@ -31,3 +31,24 @@ def test_df_fill(func):
native_df,
lambda df: getattr(df, func)(),
)


@pytest.mark.parametrize("func", ["backfill", "bfill", "ffill", "pad"])
@sql_count_checker(query_count=1)
def test_df_timedelta_fill(func):
native_df = native_pd.DataFrame(
[
[np.nan, 2, np.nan, 0],
[3, 4, np.nan, 1],
[np.nan, np.nan, np.nan, np.nan],
[np.nan, 3, np.nan, 4],
[3, np.nan, 4, np.nan],
],
columns=list("ABCD"),
).astype("timedelta64[ns]")
snow_df = pd.DataFrame(native_df)
eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: getattr(df, func)(),
)
tests/integ/modin/frame/test_compare.py (5 additions, 10 deletions)

@@ -35,16 +35,10 @@
def base_df() -> native_pd.DataFrame:
return native_pd.DataFrame(
[
[None, None, 3.1, pd.Timestamp("2024-01-01"), [130]],
[
"a",
1,
4.2,
pd.Timestamp("2024-02-01"),
[131],
],
["b", 2, 5.3, pd.Timestamp("2024-03-01"), [132]],
[None, 3, 6.4, pd.Timestamp("2024-04-01"), [133]],
[None, None, 3.1, pd.Timestamp("2024-01-01"), [130], pd.Timedelta(1)],
["a", 1, 4.2, pd.Timestamp("2024-02-01"), [131], pd.Timedelta(11)],
["b", 2, 5.3, pd.Timestamp("2024-03-01"), [132], pd.Timedelta(21)],
[None, 3, 6.4, pd.Timestamp("2024-04-01"), [133], pd.Timedelta(13)],
],
index=pd.MultiIndex.from_tuples(
[
@@ -64,6 +58,7 @@ def base_df() -> native_pd.DataFrame:
("group_2", "float_col"),
("group_2", "timestamp_col"),
("group_2", "list_col"),
("group_2", "timedelta_col"),
],
names=["column_level1", "column_level2"],
),