Fix query counts and ctor updates

sfc-gh-nkumar committed Aug 10, 2024
1 parent 2b5a772 commit 159c81b

Showing 8 changed files with 28 additions and 57 deletions.
7 changes: 3 additions & 4 deletions src/snowflake/snowpark/modin/pandas/general.py
@@ -1352,7 +1352,7 @@ def to_datetime(
infer_datetime_format: lib.NoDefault | bool = lib.no_default,
origin: Any = "unix",
cache: bool = True,
-) -> Series | DatetimeScalar | NaTType | None:
+) -> pd.DatetimeIndex | Series | DatetimeScalar | NaTType | None:
"""
Convert argument to datetime.
@@ -1459,8 +1459,7 @@ def to_datetime(
parsing):
- scalar: :class:`Timestamp` (or :class:`datetime.datetime`)
-- array-like: :class:`~snowflake.snowpark.modin.pandas.Series` with :class:`datetime64` dtype containing
-  :class:`datetime.datetime` (or
+- array-like: :class:`~snowflake.snowpark.modin.pandas.DatetimeIndex` (or
:class:`~snowflake.snowpark.modin.pandas.Series` of :class:`object` dtype containing
:class:`datetime.datetime`)
- Series: :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`datetime64` dtype (or
@@ -2170,7 +2169,7 @@ def date_range(
qc = qc.set_index_from_columns(qc.columns.tolist(), include_index=False)
# Set index column name.
qc = qc.set_index_names([name])
-return pd.DatetimeIndex(data=qc)
+return pd.DatetimeIndex(query_compiler=qc)


@snowpark_pandas_telemetry_standalone_function_decorator
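
Note on the ctor change above: data= is the public construction path that materializes client-side values, while query_compiler= hands the lazy representation straight to the new index, so no extra query is issued. A minimal sketch of that split, with hypothetical names and signature (from_pandas stands in for the plugin's ingestion; the real constructor lives in the modin plugin):

    import pandas as native_pd

    def from_pandas(obj):
        # Hypothetical stand-in for client -> Snowflake ingestion, which
        # uploads the data and wraps it in a query compiler (costs a query).
        return obj

    class DatetimeIndex:
        def __init__(self, data=None, *, query_compiler=None):
            # Exactly one construction path may be used.
            if query_compiler is not None:
                if data is not None:
                    raise ValueError("pass either data or query_compiler, not both")
                # Lazy path: adopt the compiler as-is; zero queries issued.
                self._query_compiler = query_compiler
            else:
                # Eager path: materialize client-side data into a compiler.
                self._query_compiler = from_pandas(native_pd.DatetimeIndex(data))

Routing internal callers through query_compiler= is what lets the test files below drop their "one extra query" adjustments.
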
1 change: 1 addition & 0 deletions src/snowflake/snowpark/modin/pandas/series.py
@@ -133,6 +133,7 @@ def __init__(
# Convert lazy index to Series without pulling the data to client.
if isinstance(data, pd.Index):
query_compiler = data.to_series(index=index, name=name)._query_compiler
+query_compiler = query_compiler.reset_index(drop=True)
elif isinstance(data, type(self)):
query_compiler = data._query_compiler.copy()
if index is not None:
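
The added reset_index(drop=True) matters because Index.to_series() reuses the index values as the new row labels, while Series(data=some_index) is expected to get a default positional index; the reset runs on the query compiler, so nothing is pulled to the client. The same semantics illustrated with native pandas, as a sketch:

    import pandas as pd

    idx = pd.Index([10, 20, 30], name="x")

    # Index.to_series() reuses the index values as the row labels ...
    s = idx.to_series()
    assert list(s.index) == [10, 20, 30]

    # ... so the constructor resets to a default RangeIndex afterwards.
    s = idx.to_series().reset_index(drop=True)
    assert list(s.index) == [0, 1, 2]
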
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/modin/plugin/extensions/index.py
@@ -2429,4 +2429,4 @@ def _to_datetime(
origin,
include_index=True,
)
-return DatetimeIndex(data=new_qc)
+return DatetimeIndex(query_compiler=new_qc)
39 changes: 9 additions & 30 deletions tests/integ/modin/frame/test_loc.py
@@ -146,7 +146,7 @@ def test_df_loc_get_tuple_key(
snow_row = row

query_count = 1
-if is_scalar(row) or isinstance(row, tuple) or isinstance(row, native_pd.Index):
+if is_scalar(row) or isinstance(row, tuple):
query_count = 2

with SqlCounter(
@@ -945,11 +945,7 @@ def loc_set_helper(df):
_row_key = key_converter(row_key, df)
df.loc[_row_key] = pd.DataFrame(item)

-with SqlCounter(
-    # one extra query to convert to series to set item
-    query_count=2 if key_type == "index" else 1,
-    join_count=expected_join_count,
-):
+with SqlCounter(query_count=1, join_count=expected_join_count):
eval_snowpark_pandas_result(
pd.DataFrame(native_df), native_df, loc_set_helper, inplace=True
)
@@ -971,11 +967,7 @@ def loc_set_helper(df):
_row_key = key_converter(row_key, df)
df.loc[_row_key, :] = pd.DataFrame(item)

-with SqlCounter(
-    # one extra query to convert to series to set item
-    query_count=2 if key_type == "index" else 1,
-    join_count=expected_join_count,
-):
+with SqlCounter(query_count=1, join_count=expected_join_count):
eval_snowpark_pandas_result(
pd.DataFrame(native_df), native_df, loc_set_helper, inplace=True
)
@@ -1153,9 +1145,6 @@ def loc_set_helper(df):
query_count, join_count = 1, 2
if not all(isinstance(rk_val, bool) for rk_val in row_key):
join_count += 2
-# one extra query to convert to native pandas to initialize series and set item
-if key_type == "index":
-    query_count = 2
if isinstance(col_key, native_pd.Series):
query_count += 1
with SqlCounter(query_count=query_count, join_count=join_count):
@@ -1235,10 +1224,6 @@ def loc_set_helper(df):
if isinstance(col_key, native_pd.Series):
query_count += 1

-# one extra query to convert to native pandas to initialize series and set item
-if key_type == "index":
-    query_count += 1
-
with SqlCounter(
query_count=query_count,
join_count=join_count,
@@ -1316,8 +1301,7 @@ def loc_set_helper(df):
else:
df.loc[row_key, :] = pd.DataFrame(item)

-# one extra query to convert index to native pandas to initialize series and set item
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=4):
+with SqlCounter(query_count=1, join_count=4):
if item.index.has_duplicates:
# pandas fails to update duplicated rows with duplicated item
with pytest.raises(
@@ -1641,8 +1625,7 @@ def loc_helper(df):

return _df.loc[_key]

-# one extra query to convert index to native pandas to initialize series and set item
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
default_index_snowpark_pandas_df,
default_index_native_df,
@@ -1985,8 +1968,7 @@ def loc_key_type_convert(key, is_snow_type, index_name=None):
)

# default index
-# one extra query to convert to series to set item
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
default_index_snowpark_pandas_df,
default_index_native_df,
@@ -2000,8 +1982,7 @@ def loc_key_type_convert(key, is_snow_type, index_name=None):
"index"
)
non_default_index_snowpark_pandas_df = pd.DataFrame(non_default_index_native_df)
-# one extra query to convert to series to set item
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
non_default_index_snowpark_pandas_df,
non_default_index_native_df,
@@ -2021,8 +2002,7 @@ def loc_key_type_convert(key, is_snow_type, index_name=None):
]
)
dup_snowpandas_df = pd.DataFrame(dup_native_df)
-# one extra query to convert to series to set item
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
dup_snowpandas_df,
dup_native_df,
@@ -2047,8 +2027,7 @@ def loc_key_type_convert(key, is_snow_type, index_name=None):
]
)
dup_snowpandas_df = pd.DataFrame(dup_native_df)
-# one extra query to convert to series to set item
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
dup_snowpandas_df,
dup_native_df,
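
Every count reduction in this file follows the same rule: an Index key no longer costs the extra query that used to convert it to a Series client-side. For reference, the assertion pattern SqlCounter provides, as a hypothetical simplification (the real utility lives in the test framework and hooks the Snowflake connection):

    from contextlib import contextmanager

    @contextmanager
    def sql_counter(session, query_count):
        # Count every statement executed inside the block, then assert the total.
        issued = []
        original_execute = session.execute

        def counting_execute(sql, *args, **kwargs):
            issued.append(sql)
            return original_execute(sql, *args, **kwargs)

        session.execute = counting_execute
        try:
            yield issued
        finally:
            session.execute = original_execute
            assert len(issued) == query_count, (
                f"expected {query_count} queries, got {len(issued)}: {issued}"
            )
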
6 changes: 1 addition & 5 deletions tests/integ/modin/frame/test_set_index.py
@@ -320,11 +320,7 @@ def test_set_index_pass_arrays_duplicate(obj_type1, obj_type2, drop, append, nat
obj_type2 = native_pd.Index
native_keys = [obj_type1(array), obj_type2(array)]

-query_count = 4
-# one extra query per modin index to create the series and set index
-query_count += 1 if obj_type1 == native_pd.Index else 0
-query_count += 1 if obj_type2 == native_pd.Index else 0
-with SqlCounter(query_count=query_count, join_count=2):
+with SqlCounter(query_count=4, join_count=2):
eval_snowpark_pandas_result(
snow_df,
native_df,
24 changes: 10 additions & 14 deletions tests/integ/modin/series/test_loc.py
@@ -319,7 +319,7 @@ def loc_helper(ser):
return _ser.loc[_key]

default_index_series = pd.Series(default_index_native_series)
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
default_index_series,
default_index_native_series,
@@ -480,7 +480,7 @@ def type_convert(key, is_snow_type):
# Note: here number of queries are 2 due to the data type of the series is variant and to_pandas needs to call
# typeof to get the value types
# TODO: SNOW-933782 optimize to_pandas for variant columns to only fire one query
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
default_index_snowpark_pandas_series,
default_index_native_series,
@@ -497,7 +497,7 @@ def type_convert(key, is_snow_type):
non_default_index_snowpark_pandas_series = pd.Series(
non_default_index_native_series
)
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
non_default_index_snowpark_pandas_series,
non_default_index_native_series,
@@ -514,7 +514,7 @@ def type_convert(key, is_snow_type):
]
)
dup_snowpandas_series = pd.Series(dup_native_series)
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
dup_snowpandas_series,
dup_native_series,
@@ -539,7 +539,7 @@ def type_convert(key, is_snow_type):
]
)
dup_snowpandas_series = pd.Series(dup_native_series)
-with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+with SqlCounter(query_count=1, join_count=1):
eval_snowpark_pandas_result(
dup_snowpandas_series,
dup_native_series,
@@ -776,15 +776,15 @@ def loc_set_helper(s):
s.loc[_row_key] = _item

query_count = 1
-# 6 extra queries: sum of two cases below
+# 5 extra queries: sum of two cases below
if item_type.startswith("index") and key_type.startswith("index"):
-    query_count = 7
+    query_count = 6
# 4 extra queries: 1 query to convert item index to pandas in loc_set_helper, 2 for iter, and 1 for to_list
elif item_type.startswith("index"):
    query_count = 5
-# 2 extra queries: 1 query to convert key index to pandas in loc_set_helper and 1 to convert to series to setitem
+# 1 extra query to convert to series to setitem
elif key_type.startswith("index"):
-    query_count = 3
+    query_count = 2
with SqlCounter(query_count=query_count, join_count=expected_join_count):
eval_snowpark_pandas_result(
pd.Series(series), series, loc_set_helper, inplace=True
@@ -834,11 +834,7 @@ def loc_set_helper(s):
else:
s.loc[pd.Series(row_key)] = pd.DataFrame(item)

-qc = 0
-if key_type == "index":
-    qc = 1
-
-with SqlCounter(query_count=qc):
+with SqlCounter(query_count=0):
eval_snowpark_pandas_result(
pd.Series(series),
series,
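
The branch arithmetic in the loc_set_helper hunk above is easier to read pulled out: one query for the setitem itself, plus the conversion cost of each lazy Index argument, with the key-side case now one query cheaper. A hypothetical helper mirroring the test's branches after this commit:

    def expected_query_count(key_type: str, item_type: str) -> int:
        base = 1  # the loc setitem itself
        if item_type.startswith("index") and key_type.startswith("index"):
            return base + 5  # sum of the two cases below
        if item_type.startswith("index"):
            return base + 4  # 1 to_pandas for the item, 2 for iter, 1 for to_list
        if key_type.startswith("index"):
            return base + 1  # convert the key to a Series for setitem
        return base

    assert expected_query_count("index", "index") == 6
    assert expected_query_count("series", "index") == 5
    assert expected_query_count("index", "series") == 2
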
4 changes: 2 additions & 2 deletions tests/integ/modin/test_concat.py
@@ -657,10 +657,10 @@ def test_concat_keys_with_none(df1, df2, axis):
)
def test_concat_with_keys_and_names(df1, df2, names, name1, name2, axis):
# One extra query to convert index to native pandas when creating df
-with SqlCounter(query_count=0 if name1 is None or axis == 1 else 4, join_count=0):
+with SqlCounter(query_count=0 if name1 is None or axis == 1 else 3, join_count=0):
df1 = df1.rename_axis(name1, axis=axis)
# One extra query to convert index to native pandas when creating df
-with SqlCounter(query_count=0 if name2 is None or axis == 1 else 4, join_count=0):
+with SqlCounter(query_count=0 if name2 is None or axis == 1 else 3, join_count=0):
df2 = df2.rename_axis(name2, axis=axis)

expected_join_count = (
2 changes: 1 addition & 1 deletion tests/integ/modin/tools/test_to_datetime.py
@@ -570,7 +570,7 @@ def test_to_datetime_mixed_datetime_and_string(self):
pytest.param("US/Central"),
],
)
-@sql_count_checker(query_count=3)
+@sql_count_checker(query_count=2)
def test_to_datetime_dtarr(self, tz):
# DatetimeArray
dti = native_pd.date_range("1965-04-03", periods=19, freq="2W", tz=tz)
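
The dropped query is consistent with the new return annotation in general.py: array-like input now yields a lazy DatetimeIndex rather than a Series that needed converting. For reference, the native pandas behavior this test compares against (a usage sketch, not the test itself):

    import pandas as native_pd

    dti = native_pd.date_range("1965-04-03", periods=19, freq="2W", tz="US/Central")

    # to_datetime on the backing DatetimeArray returns a DatetimeIndex.
    result = native_pd.to_datetime(dti.array)
    assert isinstance(result, native_pd.DatetimeIndex)
    assert result.tz is not None
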
