SNOW-1649753 Fix a bug when setting None on a timedelta column
sfc-gh-azhan committed Sep 3, 2024
1 parent e9ea11a commit 5ce40df
Showing 2 changed files with 142 additions and 31 deletions.
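
A minimal sketch of the behavior this commit targets, assuming an already-configured Snowpark session; the Series contents and the dtype noted in the comments are illustrative rather than taken from the diff below:

import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  (registers the Snowpark pandas backend)

# Assign a scalar NA into an existing timedelta-typed Series.
s = pd.Series([pd.Timedelta("1 days 1 hour"), pd.Timedelta("2 days 1 minute")])
s.iloc[1] = None
# With this fix, the column keeps its timedelta type instead of falling back
# to an untyped (None) Snowpark pandas type.
print(s.dtype)  # expected: timedelta64[ns]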
18 changes: 11 additions & 7 deletions src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py
@@ -2421,7 +2421,11 @@ def generate_updated_expr_for_existing_col(
elif index_is_frame:
col_obj = iff(index_data_col.is_null(), original_col, col_obj)

col_obj_type = col_obj_type if col_obj_type == origin_col_type else None
col_obj_type = (
origin_col_type
if col_obj_type == origin_col_type or (is_scalar(item) and pd.isna(item))
else None
)

return SnowparkPandasColumn(col_obj, col_obj_type)

@@ -2726,12 +2730,12 @@ def set_frame_2d_positional(
df_snowflake_quoted_identifier,
).as_(new_snowflake_quoted_identifier)
)
if (
frame.snowflake_quoted_identifier_to_snowpark_pandas_type[
original_snowflake_quoted_identifier
]
== item_type
):
original_type = frame.snowflake_quoted_identifier_to_snowpark_pandas_type[
original_snowflake_quoted_identifier
]
if is_scalar(item) and pd.isna(item):
new_data_column_types.append(original_type)
elif original_type == item_type:
new_data_column_types.append(item_type)
else:
new_data_column_types.append(None)
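The hunks above apply one rule in two places: the resulting Snowpark pandas column type keeps the original column's type when the assigned item is a scalar NA (None, NaT, NA) or when the item's type already matches, and is cleared to None otherwise. A pure-Python sketch of that rule follows; resolve_result_type is a hypothetical name standing in for the inline logic, and is_scalar is imported from pandas here rather than from modin.pandas.utils as in the source.

import pandas as native_pd
from pandas.api.types import is_scalar

def resolve_result_type(origin_col_type, item_type, item):
    # A scalar NA (None, NaT, NA) carries no type of its own, so it should not
    # erase the original column's Snowpark pandas type (e.g. a timedelta type).
    if is_scalar(item) and native_pd.isna(item):
        return origin_col_type
    # Otherwise keep the original type only when the item's type matches it;
    # a mismatch clears the type annotation.
    return origin_col_type if item_type == origin_col_type else None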
155 changes: 131 additions & 24 deletions tests/integ/modin/types/test_timedelta_indexing.py
@@ -11,7 +11,7 @@
from modin.pandas.utils import is_scalar

from snowflake.snowpark.exceptions import SnowparkSQLException
from tests.integ.modin.sql_counter import SqlCounter
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result


@@ -163,7 +163,18 @@ def test_df_getitem_timedelta():
)


def test_series_indexing_set_timedelta():
@sql_count_checker(query_count=1, join_count=2)
@pytest.mark.parametrize(
"key, item",
[
[2, pd.Timedelta("2 days 2 hours")], # single value
[slice(2, None), pd.Timedelta("2 days 2 hours")], # multiple values
[slice(2, None), None], # multiple None values
[slice(2, None), pd.NaT], # multiple NaT values
[slice(None, None), pd.NA], # all NA values
],
)
def test_series_indexing_set_timedelta(key, item):
td_s = native_pd.Series(
[
native_pd.Timedelta("1 days 1 hour"),
@@ -178,21 +189,11 @@ def iloc_set(key, item, s):
s.iloc[key] = item
return s

with SqlCounter(query_count=1, join_count=2):
# single value
eval_snowpark_pandas_result(
snow_td_s.copy(),
td_s,
functools.partial(iloc_set, 2, pd.Timedelta("2 days 2 hours")),
)

with SqlCounter(query_count=1, join_count=2):
# multi values
eval_snowpark_pandas_result(
snow_td_s.copy(),
td_s,
functools.partial(iloc_set, slice(2, None), pd.Timedelta("2 days 2 hours")),
)
eval_snowpark_pandas_result(
snow_td_s.copy(),
td_s,
functools.partial(iloc_set, key, item),
)


def test_df_indexing_set_timedelta():
@@ -327,7 +328,108 @@ def loc_set(key, item, df):
run_test(key, item, natvie_df=td_int, api=loc_set)


def test_df_indexing_enlargement_timedelta():
def test_df_indexing_set_timedelta_with_none():
td = native_pd.DataFrame(
{
"a": [
native_pd.Timedelta("1 days 1 hour"),
native_pd.Timedelta("2 days 1 minute"),
native_pd.Timedelta("3 days 1 nanoseconds"),
native_pd.Timedelta("100 nanoseconds"),
],
"b": native_pd.timedelta_range("1 hour", "1 day", 4),
"c": [1, 2, 3, 4],
}
)
snow_td = pd.DataFrame(td)

def iloc_set(key, item, df):
df.iloc[key] = item
return df

def run_test(key, item, native_df=td, api=iloc_set):
eval_snowpark_pandas_result(
snow_td.copy(), native_df.copy(), functools.partial(api, key, item)
)

item = None

with SqlCounter(query_count=1, join_count=2):
# single value
key = (1, 1)
run_test(key, item)

with SqlCounter(query_count=1, join_count=2):
# single column
key = (..., 0)
run_test(key, item)

with SqlCounter(query_count=1, join_count=2):
# multi columns
key = (..., [0, 1])
run_test(key, item)

with SqlCounter(query_count=1, join_count=3):
# multi columns with array
key = (..., [0, 1])
run_test(key, [item] * 2)

def df_set(key, item, df):
df[key] = item
return df

with SqlCounter(query_count=1, join_count=0):
# single column
key = "b"
run_test(key, item, api=df_set)

with SqlCounter(query_count=1, join_count=0):
# multi columns
key = ["a", "b"]
run_test(key, item, api=df_set)

with SqlCounter(query_count=1, join_count=0):
# multi columns with array
key = ["a", "b"]
run_test(key, [item] * 2, api=df_set)

def loc_set(key, item, df):
df.loc[key] = item
return df

with SqlCounter(query_count=1, join_count=1):
# single value
key = (1, "a")
run_test(key, item, api=loc_set)

with SqlCounter(query_count=1, join_count=1):
# single value
key = (1, "a")
run_test(key, None, api=loc_set)

with SqlCounter(query_count=1, join_count=0):
# single column
key = (slice(None, None, None), "a")
run_test(key, item, api=loc_set)

with SqlCounter(query_count=1, join_count=0):
# multi columns
key = (slice(None, None, None), ["a", "b"])
run_test(key, item, api=loc_set)

with SqlCounter(query_count=1, join_count=0):
# multi columns with array
key = (slice(None, None, None), ["a", "b"])
run_test(key, [item] * 2, api=loc_set)

with SqlCounter(query_count=1, join_count=0):
# multi columns with array
key = (slice(None, None, None), ["a", "b"])
run_test(key, [item] * 2, api=loc_set)


@pytest.mark.parametrize("item", [None, pd.Timedelta("1 hour")])
def test_df_indexing_enlargement_timedelta(item):
td = native_pd.DataFrame(
{
"a": [
@@ -347,7 +449,7 @@ def setitem_enlargement(key, item, df):
return df

key = "x"
item = pd.Timedelta("2 hours")

with SqlCounter(query_count=1, join_count=0):
eval_snowpark_pandas_result(
snow_td.copy(), td.copy(), functools.partial(setitem_enlargement, key, item)
@@ -384,11 +486,16 @@ def loc_enlargement(key, item, df):
key = (10, slice(None, None, None))

with SqlCounter(query_count=1, join_count=1):
# dtypes does not change while in native pandas, col "c"'s type will change to object
assert_series_equal(
loc_enlargement(key, item, snow_td.copy()).to_pandas().dtypes,
snow_td.dtypes,
)
if pd.isna(item):
eval_snowpark_pandas_result(
snow_td.copy(), td.copy(), functools.partial(loc_enlargement, key, item)
)
else:
# dtypes do not change, while in native pandas col "c"'s dtype changes to object
assert_series_equal(
loc_enlargement(key, item, snow_td.copy()).to_pandas().dtypes,
snow_td.dtypes,
)


@pytest.mark.parametrize(
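A short usage sketch of the enlargement case the last hunk parametrizes, again assuming an active Snowpark session; the frame contents and the exact loc key are illustrative:

import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame({"a": [pd.Timedelta("1 hour")], "c": [1]})
# Row enlargement with a scalar NA: the new row holds NaT in the timedelta
# column, and with this fix the column's timedelta type is preserved, matching
# native pandas for the NA case.
df.loc[10] = None
print(df.dtypes)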
