From 159c81bdcc19c1883c878c38c76dd44b9e71d421 Mon Sep 17 00:00:00 2001
From: Naresh Kumar
Date: Fri, 9 Aug 2024 12:20:42 -0700
Subject: [PATCH] Fix query counts and ctor updates

---
 .../snowpark/modin/pandas/general.py          |  7 ++--
 src/snowflake/snowpark/modin/pandas/series.py |  1 +
 .../snowpark/modin/plugin/extensions/index.py |  2 +-
 tests/integ/modin/frame/test_loc.py           | 39 +++++--------------
 tests/integ/modin/frame/test_set_index.py     |  6 +--
 tests/integ/modin/series/test_loc.py          | 24 +++++-------
 tests/integ/modin/test_concat.py              |  4 +-
 tests/integ/modin/tools/test_to_datetime.py   |  2 +-
 8 files changed, 28 insertions(+), 57 deletions(-)

diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py
index 4161d316b0e..af0369771bf 100644
--- a/src/snowflake/snowpark/modin/pandas/general.py
+++ b/src/snowflake/snowpark/modin/pandas/general.py
@@ -1352,7 +1352,7 @@ def to_datetime(
     infer_datetime_format: lib.NoDefault | bool = lib.no_default,
     origin: Any = "unix",
     cache: bool = True,
-) -> Series | DatetimeScalar | NaTType | None:
+) -> pd.DatetimeIndex | Series | DatetimeScalar | NaTType | None:
     """
     Convert argument to datetime.
 
@@ -1459,8 +1459,7 @@
     parsing):
 
     - scalar: :class:`Timestamp` (or :class:`datetime.datetime`)
-    - array-like: :class:`~snowflake.snowpark.modin.pandas.Series` with :class:`datetime64` dtype containing
-      :class:`datetime.datetime` (or
+    - array-like: :class:`~snowflake.snowpark.modin.pandas.DatetimeIndex` (or
       :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`object` dtype containing
       :class:`datetime.datetime`)
     - Series: :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`datetime64` dtype (or
@@ -2170,7 +2169,7 @@ def date_range(
         qc = qc.set_index_from_columns(qc.columns.tolist(), include_index=False)
         # Set index column name.
         qc = qc.set_index_names([name])
-    return pd.DatetimeIndex(data=qc)
+    return pd.DatetimeIndex(query_compiler=qc)
 
 
 @snowpark_pandas_telemetry_standalone_function_decorator
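A minimal sketch of the user-visible effect of the general.py change above, assuming the `modin.pandas` plus Snowpark plugin imports this repo's tests use; the input values are illustrative:

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401

    # Array-like input now comes back as a lazy DatetimeIndex built through
    # DatetimeIndex(query_compiler=...) instead of a datetime64 Series.
    idx = pd.to_datetime(["2024-08-09", "2024-08-10"])
    assert isinstance(idx, pd.DatetimeIndex)

date_range takes the same path: its query compiler is wrapped directly by the DatetimeIndex constructor rather than being re-ingested as data.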
diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py
index a494b513de5..f268a21306b 100644
--- a/src/snowflake/snowpark/modin/pandas/series.py
+++ b/src/snowflake/snowpark/modin/pandas/series.py
@@ -133,6 +133,7 @@ def __init__(
         # Convert lazy index to Series without pulling the data to client.
         if isinstance(data, pd.Index):
             query_compiler = data.to_series(index=index, name=name)._query_compiler
+            query_compiler = query_compiler.reset_index(drop=True)
         elif isinstance(data, type(self)):
             query_compiler = data._query_compiler.copy()
             if index is not None:
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py
index e11ac325f0d..a3b4265708a 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py
@@ -2429,4 +2429,4 @@ def _to_datetime(
             origin,
             include_index=True,
         )
-        return DatetimeIndex(data=new_qc)
+        return DatetimeIndex(query_compiler=new_qc)
diff --git a/tests/integ/modin/frame/test_loc.py b/tests/integ/modin/frame/test_loc.py
index f258f261b51..1012a0d3959 100644
--- a/tests/integ/modin/frame/test_loc.py
+++ b/tests/integ/modin/frame/test_loc.py
@@ -146,7 +146,7 @@ def test_df_loc_get_tuple_key(
     snow_row = row
 
     query_count = 1
-    if is_scalar(row) or isinstance(row, tuple) or isinstance(row, native_pd.Index):
+    if is_scalar(row) or isinstance(row, tuple):
         query_count = 2
 
     with SqlCounter(
@@ -945,11 +945,7 @@ def loc_set_helper(df):
             _row_key = key_converter(row_key, df)
             df.loc[_row_key] = pd.DataFrame(item)
 
-    with SqlCounter(
-        # one extra query to convert to series to set item
-        query_count=2 if key_type == "index" else 1,
-        join_count=expected_join_count,
-    ):
+    with SqlCounter(query_count=1, join_count=expected_join_count):
         eval_snowpark_pandas_result(
             pd.DataFrame(native_df), native_df, loc_set_helper, inplace=True
         )
@@ -971,11 +967,7 @@ def loc_set_helper(df):
             _row_key = key_converter(row_key, df)
             df.loc[_row_key, :] = pd.DataFrame(item)
 
-    with SqlCounter(
-        # one extra query to convert to series to set item
-        query_count=2 if key_type == "index" else 1,
-        join_count=expected_join_count,
-    ):
+    with SqlCounter(query_count=1, join_count=expected_join_count):
         eval_snowpark_pandas_result(
             pd.DataFrame(native_df), native_df, loc_set_helper, inplace=True
         )
@@ -1153,9 +1145,6 @@ def loc_set_helper(df):
     query_count, join_count = 1, 2
     if not all(isinstance(rk_val, bool) for rk_val in row_key):
         join_count += 2
-    # one extra query to convert to native pandas to initialize series and set item
-    if key_type == "index":
-        query_count = 2
     if isinstance(col_key, native_pd.Series):
         query_count += 1
     with SqlCounter(query_count=query_count, join_count=join_count):
@@ -1235,10 +1224,6 @@ def loc_set_helper(df):
     if isinstance(col_key, native_pd.Series):
         query_count += 1
 
-    # one extra query to convert to native pandas to initialize series and set item
-    if key_type == "index":
-        query_count += 1
-
     with SqlCounter(
         query_count=query_count,
         join_count=join_count,
@@ -1316,8 +1301,7 @@ def loc_set_helper(df):
         else:
             df.loc[row_key, :] = pd.DataFrame(item)
 
-    # one extra query to convert index to native pandas to initialize series and set item
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=4):
+    with SqlCounter(query_count=1, join_count=4):
         if item.index.has_duplicates:
             # pandas fails to update duplicated rows with duplicated item
             with pytest.raises(
@@ -1641,8 +1625,7 @@ def loc_helper(df):
 
         return _df.loc[_key]
 
-    # one extra query to convert index to native pandas to initialize series and set item
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             default_index_snowpark_pandas_df,
             default_index_native_df,
@@ -1985,8 +1968,7 @@ def loc_key_type_convert(key, is_snow_type, index_name=None):
     )
 
     # default index
-    # one extra query to convert to series to set item
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             default_index_snowpark_pandas_df,
             default_index_native_df,
@@ -2000,8 +1982,7 @@ def loc_key_type_convert(key, is_snow_type, index_name=None):
         "index"
     )
     non_default_index_snowpark_pandas_df = pd.DataFrame(non_default_index_native_df)
-    # one extra query to convert to series to set item
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             non_default_index_snowpark_pandas_df,
             non_default_index_native_df,
@@ -2021,8 +2002,7 @@ def loc_key_type_convert(key, is_snow_type, index_name=None):
         ]
     )
     dup_snowpandas_df = pd.DataFrame(dup_native_df)
-    # one extra query to convert to series to set item
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             dup_snowpandas_df,
             dup_native_df,
@@ -2047,8 +2027,7 @@ def loc_key_type_convert(key, is_snow_type, index_name=None):
         ]
     )
     dup_snowpandas_df = pd.DataFrame(dup_native_df)
-    # one extra query to convert to series to set item
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             dup_snowpandas_df,
             dup_native_df,
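The series.py hunk and the test_loc.py count drops above share one cause: building a Series from a lazy Index no longer pulls the index values to the client, so .loc calls keyed by an Index lose their extra conversion query. A sketch under the same import assumptions as the previous example:

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401

    # Series(<lazy Index>) now resets to a default positional index
    # instead of round-tripping the index through native pandas.
    row_key = pd.Index(["a", "c"])
    key_series = pd.Series(row_key)  # construction stays lazy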
diff --git a/tests/integ/modin/frame/test_set_index.py b/tests/integ/modin/frame/test_set_index.py
index 15566d630f1..e0088673282 100644
--- a/tests/integ/modin/frame/test_set_index.py
+++ b/tests/integ/modin/frame/test_set_index.py
@@ -320,11 +320,7 @@ def test_set_index_pass_arrays_duplicate(obj_type1, obj_type2, drop, append, nat
         obj_type2 = native_pd.Index
     native_keys = [obj_type1(array), obj_type2(array)]
 
-    query_count = 4
-    # one extra query per modin index to create the series and set index
-    query_count += 1 if obj_type1 == native_pd.Index else 0
-    query_count += 1 if obj_type2 == native_pd.Index else 0
-    with SqlCounter(query_count=query_count, join_count=2):
+    with SqlCounter(query_count=4, join_count=2):
         eval_snowpark_pandas_result(
             snow_df,
             native_df,
diff --git a/tests/integ/modin/series/test_loc.py b/tests/integ/modin/series/test_loc.py
index 32c1bf64c4a..21fbf6aeafa 100644
--- a/tests/integ/modin/series/test_loc.py
+++ b/tests/integ/modin/series/test_loc.py
@@ -319,7 +319,7 @@ def loc_helper(ser):
         return _ser.loc[_key]
 
     default_index_series = pd.Series(default_index_native_series)
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             default_index_series,
             default_index_native_series,
@@ -480,7 +480,7 @@ def type_convert(key, is_snow_type):
     # Note: here number of queries are 2 due to the data type of the series is variant and to_pandas needs to call
     # typeof to get the value types
     # TODO: SNOW-933782 optimize to_pandas for variant columns to only fire one query
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             default_index_snowpark_pandas_series,
             default_index_native_series,
@@ -497,7 +497,7 @@ def type_convert(key, is_snow_type):
     non_default_index_snowpark_pandas_series = pd.Series(
         non_default_index_native_series
     )
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             non_default_index_snowpark_pandas_series,
             non_default_index_native_series,
@@ -514,7 +514,7 @@ def type_convert(key, is_snow_type):
         ]
     )
     dup_snowpandas_series = pd.Series(dup_native_series)
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             dup_snowpandas_series,
             dup_native_series,
@@ -539,7 +539,7 @@ def type_convert(key, is_snow_type):
         ]
     )
     dup_snowpandas_series = pd.Series(dup_native_series)
-    with SqlCounter(query_count=2 if key_type == "index" else 1, join_count=1):
+    with SqlCounter(query_count=1, join_count=1):
         eval_snowpark_pandas_result(
             dup_snowpandas_series,
             dup_native_series,
@@ -776,15 +776,15 @@ def loc_set_helper(s):
             s.loc[_row_key] = _item
 
     query_count = 1
-    # 6 extra queries: sum of two cases below
+    # 5 extra queries: sum of two cases below
     if item_type.startswith("index") and key_type.startswith("index"):
-        query_count = 7
+        query_count = 6
     # 4 extra queries: 1 query to convert item index to pandas in loc_set_helper, 2 for iter, and 1 for to_list
     elif item_type.startswith("index"):
         query_count = 5
-    # 2 extra queries: 1 query to convert key index to pandas in loc_set_helper and 1 to convert to series to setitem
+    # 1 extra query to convert to series to setitem
    elif key_type.startswith("index"):
-        query_count = 3
+        query_count = 2
     with SqlCounter(query_count=query_count, join_count=expected_join_count):
         eval_snowpark_pandas_result(
             pd.Series(series), series, loc_set_helper, inplace=True
@@ -834,11 +834,7 @@ def loc_set_helper(s):
         else:
             s.loc[pd.Series(row_key)] = pd.DataFrame(item)
 
-    qc = 0
-    if key_type == "index":
-        qc = 1
-
-    with SqlCounter(query_count=qc):
+    with SqlCounter(query_count=0):
         eval_snowpark_pandas_result(
             pd.Series(series),
             series,
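For reference, SqlCounter wraps a block and asserts how many queries and joins it issues. A sketch of the updated budget for an Index-keyed lookup; the helper's import path is an assumption (it lives in this test suite), and the data is illustrative:

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401
    from tests.integ.modin.sql_counter import SqlCounter  # import path assumed

    ser = pd.Series([10, 20, 30], index=["a", "b", "c"])
    # One query to materialize the result, one join against the lazy key;
    # the second query previously budgeted for key conversion is gone.
    with SqlCounter(query_count=1, join_count=1):
        ser.loc[pd.Index(["a", "b"])].to_pandas()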
diff --git a/tests/integ/modin/test_concat.py b/tests/integ/modin/test_concat.py
index 9437bb6a36c..628af787ac4 100644
--- a/tests/integ/modin/test_concat.py
+++ b/tests/integ/modin/test_concat.py
@@ -657,10 +657,10 @@ def test_concat_keys_with_none(df1, df2, axis):
 )
 def test_concat_with_keys_and_names(df1, df2, names, name1, name2, axis):
     # One extra query to convert index to native pandas when creating df
-    with SqlCounter(query_count=0 if name1 is None or axis == 1 else 4, join_count=0):
+    with SqlCounter(query_count=0 if name1 is None or axis == 1 else 3, join_count=0):
         df1 = df1.rename_axis(name1, axis=axis)
     # One extra query to convert index to native pandas when creating df
-    with SqlCounter(query_count=0 if name2 is None or axis == 1 else 4, join_count=0):
+    with SqlCounter(query_count=0 if name2 is None or axis == 1 else 3, join_count=0):
         df2 = df2.rename_axis(name2, axis=axis)
 
     expected_join_count = (
diff --git a/tests/integ/modin/tools/test_to_datetime.py b/tests/integ/modin/tools/test_to_datetime.py
index 07fb4aefebf..a0ac55958a9 100644
--- a/tests/integ/modin/tools/test_to_datetime.py
+++ b/tests/integ/modin/tools/test_to_datetime.py
@@ -570,7 +570,7 @@ def test_to_datetime_mixed_datetime_and_string(self):
             pytest.param("US/Central"),
         ],
     )
-    @sql_count_checker(query_count=3)
+    @sql_count_checker(query_count=2)
     def test_to_datetime_dtarr(self, tz):
         # DatetimeArray
         dti = native_pd.date_range("1965-04-03", periods=19, freq="2W", tz=tz)
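The to_datetime budget above drops from three queries to two for DatetimeArray input, consistent with the constructor change: the result is wrapped as a DatetimeIndex without a client-side Series detour. A sketch with illustrative data mirroring the test:

    import pandas as native_pd
    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401

    # .array yields the native DatetimeArray that the test feeds in.
    arr = native_pd.date_range(
        "1965-04-03", periods=19, freq="2W", tz="US/Central"
    ).array
    result = pd.to_datetime(arr)  # lazy Snowpark pandas DatetimeIndex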