Skip to content

Commit

Permalink
Fix query counts and ctor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-nkumar committed Aug 9, 2024
1 parent 2b5a772 commit 3824fb6
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 23 deletions.
7 changes: 3 additions & 4 deletions src/snowflake/snowpark/modin/pandas/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ def to_datetime(
infer_datetime_format: lib.NoDefault | bool = lib.no_default,
origin: Any = "unix",
cache: bool = True,
) -> Series | DatetimeScalar | NaTType | None:
) -> pd.DatetimeIndex | Series | DatetimeScalar | NaTType | None:
"""
Convert argument to datetime.
Expand Down Expand Up @@ -1459,8 +1459,7 @@ def to_datetime(
parsing):
- scalar: :class:`Timestamp` (or :class:`datetime.datetime`)
- array-like: :class:`~snowflake.snowpark.modin.pandas.Series` with :class:`datetime64` dtype containing
:class:`datetime.datetime` (or
- array-like: :class:`~snowflake.snowpark.modin.pandas.DatetimeIndex` (or
:class: :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`object` dtype containing
:class:`datetime.datetime`)
- Series: :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`datetime64` dtype (or
Expand Down Expand Up @@ -2170,7 +2169,7 @@ def date_range(
qc = qc.set_index_from_columns(qc.columns.tolist(), include_index=False)
# Set index column name.
qc = qc.set_index_names([name])
return pd.DatetimeIndex(data=qc)
return pd.DatetimeIndex(query_compiler=qc)


@snowpark_pandas_telemetry_standalone_function_decorator
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
# Convert lazy index to Series without pulling the data to client.
if isinstance(data, pd.Index):
query_compiler = data.to_series(index=index, name=name)._query_compiler
query_compiler = query_compiler.reset_index(drop=True)
elif isinstance(data, type(self)):
query_compiler = data._query_compiler.copy()
if index is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,4 +2429,4 @@ def _to_datetime(
origin,
include_index=True,
)
return DatetimeIndex(data=new_qc)
return DatetimeIndex(query_compiler=new_qc)
12 changes: 2 additions & 10 deletions tests/integ/modin/frame/test_loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,11 +945,7 @@ def loc_set_helper(df):
_row_key = key_converter(row_key, df)
df.loc[_row_key] = pd.DataFrame(item)

with SqlCounter(
# one extra query to convert to series to set item
query_count=2 if key_type == "index" else 1,
join_count=expected_join_count,
):
with SqlCounter(query_count=1, join_count=expected_join_count):
eval_snowpark_pandas_result(
pd.DataFrame(native_df), native_df, loc_set_helper, inplace=True
)
Expand All @@ -971,11 +967,7 @@ def loc_set_helper(df):
_row_key = key_converter(row_key, df)
df.loc[_row_key, :] = pd.DataFrame(item)

with SqlCounter(
# one extra query to convert to series to set item
query_count=2 if key_type == "index" else 1,
join_count=expected_join_count,
):
with SqlCounter(query_count=1, join_count=expected_join_count):
eval_snowpark_pandas_result(
pd.DataFrame(native_df), native_df, loc_set_helper, inplace=True
)
Expand Down
6 changes: 1 addition & 5 deletions tests/integ/modin/frame/test_set_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,7 @@ def test_set_index_pass_arrays_duplicate(obj_type1, obj_type2, drop, append, nat
obj_type2 = native_pd.Index
native_keys = [obj_type1(array), obj_type2(array)]

query_count = 4
# one extra query per modin index to create the series and set index
query_count += 1 if obj_type1 == native_pd.Index else 0
query_count += 1 if obj_type2 == native_pd.Index else 0
with SqlCounter(query_count=query_count, join_count=2):
with SqlCounter(query_count=4, join_count=2):
eval_snowpark_pandas_result(
snow_df,
native_df,
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,10 +657,10 @@ def test_concat_keys_with_none(df1, df2, axis):
)
def test_concat_with_keys_and_names(df1, df2, names, name1, name2, axis):
# One extra query to convert index to native pandas when creating df
with SqlCounter(query_count=0 if name1 is None or axis == 1 else 4, join_count=0):
with SqlCounter(query_count=0 if name1 is None or axis == 1 else 3, join_count=0):
df1 = df1.rename_axis(name1, axis=axis)
# One extra query to convert index to native pandas when creating df
with SqlCounter(query_count=0 if name2 is None or axis == 1 else 4, join_count=0):
with SqlCounter(query_count=0 if name2 is None or axis == 1 else 3, join_count=0):
df2 = df2.rename_axis(name2, axis=axis)

expected_join_count = (
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/modin/tools/test_to_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def test_to_datetime_mixed_datetime_and_string(self):
pytest.param("US/Central"),
],
)
@sql_count_checker(query_count=3)
@sql_count_checker(query_count=2)
def test_to_datetime_dtarr(self, tz):
# DatetimeArray
dti = native_pd.date_range("1965-04-03", periods=19, freq="2W", tz=tz)
Expand Down

0 comments on commit 3824fb6

Please sign in to comment.