From 5f94430a2048255c8716549919c33fdd3492e0c0 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Wed, 11 Sep 2024 10:32:54 -0700 Subject: [PATCH 1/3] [CHERRY-PICK][Release-v1.22.0] Cherry pick release change for v1.22.0 (#2269) cherry pick https://github.com/snowflakedb/snowpark-python/pull/2268 --- CHANGELOG.md | 18 +++++++++--------- recipe/meta.yaml | 2 +- src/snowflake/snowpark/version.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fea42391259..56d53a58ad6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,15 @@ # Release History -## 1.22.0 (TBD) +## 1.23.0 (TBD) + + +## 1.22.0 (2024-09-10) ### Snowpark Python API Updates ### New Features -- Added following new functions in `snowflake.snowpark.functions`: +- Added the following new functions in `snowflake.snowpark.functions`: - `array_remove` - `ln` @@ -46,14 +49,14 @@ - Fixed a bug in `session.read.csv` that caused an error when setting `PARSE_HEADER = True` in an externally defined file format. - Fixed a bug in query generation from set operations that allowed generation of duplicate queries when children have common subqueries. - Fixed a bug in `session.get_session_stage` that referenced a non-existing stage after switching database or schema. -- Fixed a bug where calling `DataFrame.to_snowpark_pandas_dataframe` without explicitly initializing the Snowpark pandas plugin caused an error. +- Fixed a bug where calling `DataFrame.to_snowpark_pandas` without explicitly initializing the Snowpark pandas plugin caused an error. - Fixed a bug where using the `explode` function in dynamic table creation caused a SQL compilation error due to improper boolean type casting on the `outer` parameter. ### Snowpark Local Testing Updates #### New Features -- Added support for type coercion when passing columns as input to udf calls +- Added support for type coercion when passing columns as input to UDF calls. - Added support for `Index.identical`. #### Bug Fixes @@ -113,9 +116,10 @@ - Improved `pd.to_datetime` to handle all local input cases. - Create a lazy index from another lazy index without pulling data to client. - Raised `NotImplementedError` for Index bitwise operators. -- Display a clearer error message when `Index.names` is set to a non-like-like object. +- Display a more clear error message when `Index.names` is set to a non-like-like object. - Raise a warning whenever MultiIndex values are pulled in locally. - Improve warning message for `pd.read_snowflake` include the creation reason when temp table creation is triggered. +- Improve performance for `DataFrame.set_index`, or setting `DataFrame.index` or `Series.index` by avoiding checks require eager evaluation. As a consequence, when the new index that does not match the current `Series`/`DataFrame` object length, a `ValueError` is no longer raised. Instead, when the `Series`/`DataFrame` object is longer than the provided index, the `Series`/`DataFrame`'s new index is filled with `NaN` values for the "extra" elements. Otherwise, the extra values in the provided index are ignored. #### Bug Fixes @@ -126,10 +130,6 @@ - Fixed a bug where `Series.reindex` and `DataFrame.reindex` did not update the result index's name correctly. - Fixed a bug where `Series.take` did not error when `axis=1` was specified. 
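A minimal sketch of the index-length behavior noted in the `DataFrame.set_index` improvement entry above; the column values, index values, and import aliases are illustrative assumptions (not part of this changelog), and an active Snowpark session is assumed:

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame({"a": [1, 2, 3]})

# Shorter index than the frame: no ValueError is raised; the frame keeps its
# length and the missing index positions are filled with NaN.
df.index = pd.Index([10, 20])

# Longer index than the frame: the extra index values are ignored.
df.index = pd.Index([10, 20, 30, 40])
```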
-#### Behavior Change - -- When calling `DataFrame.set_index`, or setting `DataFrame.index` or `Series.index`, with a new index that does not match the current length of the `Series`/`DataFrame` object, a `ValueError` is no longer raised. When the `Series`/`DataFrame` object is longer than the new index, the `Series`/`DataFrame`'s new index is filled with `NaN` values for the "extra" elements. When the `Series`/`DataFrame` object is shorter than the new index, the extra values in the new index are ignored—`Series` and `DataFrame` stay the same length `n`, and use only the first `n` values of the new index. - ## 1.21.1 (2024-09-05) diff --git a/recipe/meta.yaml b/recipe/meta.yaml index cf1f2c9ad70..9aed0375d0c 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -1,5 +1,5 @@ {% set name = "snowflake-snowpark-python" %} -{% set version = "1.21.1" %} +{% set version = "1.22.0" %} package: name: {{ name|lower }} diff --git a/src/snowflake/snowpark/version.py b/src/snowflake/snowpark/version.py index 3955dbbbf33..4b7ad25b189 100644 --- a/src/snowflake/snowpark/version.py +++ b/src/snowflake/snowpark/version.py @@ -4,4 +4,4 @@ # # Update this for the versions -VERSION = (1, 21, 1) +VERSION = (1, 22, 0) From 195226e5c640646bcb6c9dffd5e7b9bbd02650d1 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Wed, 11 Sep 2024 10:50:24 -0700 Subject: [PATCH 2/3] SNOW-1548942 Implement support for dataframe apply axis=0 (#2241) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. SNOW-1548942 Implement support for dataframe apply axis=0 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. This pr adds support for df.apply with axis=0 --- CHANGELOG.md | 1 + .../modin/supported/dataframe_supported.rst | 5 +- .../modin/plugin/_internal/apply_utils.py | 32 +- .../modin/plugin/_internal/join_utils.py | 4 +- .../compiler/snowflake_query_compiler.py | 335 +++++++-- .../modin/plugin/docstrings/dataframe.py | 24 +- tests/integ/modin/frame/test_apply.py | 89 ++- tests/integ/modin/frame/test_apply_axis_0.py | 653 ++++++++++++++++++ 8 files changed, 1020 insertions(+), 123 deletions(-) create mode 100644 tests/integ/modin/frame/test_apply_axis_0.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 56d53a58ad6..b119e75573a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -108,6 +108,7 @@ - Added support for creating a `DatetimeIndex` from an `Index` of numeric or string type. - Added support for string indexing with `Timedelta` objects. - Added support for `Series.dt.total_seconds` method. +- Added support for `DataFrame.apply(axis=0)`. 
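A hedged usage sketch of the new `DataFrame.apply(axis=0)` support described in the entry above; the data values are illustrative assumptions and an active Snowpark session is assumed:

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame({"A": [1.0, 2.2], "B": [3.0, 4.5]})

# Each column is passed to the function as a pandas Series. A scalar result
# per column collapses the output to a Series indexed by the column labels.
col_sums = df.apply(lambda col: col.sum(), axis=0)

# A Series result per column keeps the output as a DataFrame.
shifted = df.apply(lambda col: col + 1, axis=0)
```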
#### Improvements diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst index 6bb214e3bd6..54858063e54 100644 --- a/docs/source/modin/supported/dataframe_supported.rst +++ b/docs/source/modin/supported/dataframe_supported.rst @@ -84,7 +84,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``any`` | P | | ``N`` for non-integer/boolean types | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``apply`` | P | | ``N`` if ``axis == 0`` or ``func`` is not callable | +| ``apply`` | P | | ``N`` if ``func`` is not callable | | | | | or ``result_type`` is given or ``args`` and | | | | | ``kwargs`` contain DataFrame or Series | | | | | ``N`` if ``func`` maps to different column labels. | @@ -471,8 +471,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``to_xml`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``transform`` | P | | Only callable and string parameters are supported.| -| | | | list and dict parameters are not supported. | +| ``transform`` | P | | ``Y`` if ``func`` is callable. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``transpose`` | P | | See ``T`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index b58ba4f50ea..f87cdcd2e47 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -81,7 +81,7 @@ class GroupbyApplySortMethod(Enum): def check_return_variant_and_get_return_type(func: Callable) -> tuple[bool, DataType]: """Check whether the function returns a variant in Snowflake, and get its return type.""" - return_type, _ = get_types_from_type_hints(func, TempObjectType.FUNCTION) + return_type = deduce_return_type_from_function(func) if return_type is None or isinstance( return_type, (VariantType, PandasSeriesType, PandasDataFrameType) ): @@ -390,6 +390,7 @@ def create_udtf_for_groupby_apply( series_groupby: bool, by_types: list[DataType], existing_identifiers: list[str], + force_list_like_to_series: bool = False, ) -> UserDefinedTableFunction: """ Create a UDTF from the Python function for groupby.apply. @@ -480,6 +481,7 @@ def create_udtf_for_groupby_apply( series_groupby: Whether we are performing a SeriesGroupBy.apply() instead of DataFrameGroupBy.apply() by_types: The snowflake types of the by columns. existing_identifiers: List of existing column identifiers; these are omitted when creating new column identifiers. 
+ force_list_like_to_series: Force the function result to series if it is list-like Returns ------- @@ -553,6 +555,17 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def # https://github.com/snowflakedb/snowpandas/pull/823/files#r1507286892 input_object = input_object.infer_objects() func_result = func(input_object, *args, **kwargs) + if ( + force_list_like_to_series + and not isinstance(func_result, native_pd.Series) + and native_pd.api.types.is_list_like(func_result) + ): + if len(func_result) == 1: + func_result = func_result[0] + else: + func_result = native_pd.Series(func_result) + if len(func_result) == len(df.index): + func_result.index = df.index if isinstance(func_result, native_pd.Series): if series_groupby: func_result_as_frame = func_result.to_frame() @@ -754,7 +767,7 @@ def __init__(self) -> None: def convert_numpy_int_result_to_int(value: Any) -> Any: """ - If the result is a numpy int, convert it to a python int. + If the result is a numpy int (or bool), convert it to a python int (or bool.) Use this function to make UDF results JSON-serializable. numpy ints are not JSON-serializable, but python ints are. Note that this function cannot make @@ -772,9 +785,14 @@ def convert_numpy_int_result_to_int(value: Any) -> Any: Returns ------- - int(value) if the value is a numpy int, otherwise the value. + int(value) if the value is a numpy int, + bool(value) if the value is a numpy bool, otherwise the value. """ - return int(value) if np.issubdtype(type(value), np.integer) else value + return ( + int(value) + if np.issubdtype(type(value), np.integer) + else (bool(value) if np.issubdtype(type(value), np.bool_) else value) + ) def deduce_return_type_from_function( @@ -887,7 +905,7 @@ def get_metadata_from_groupby_apply_pivot_result_column_names( input: get_metadata_from_groupby_apply_pivot_result_column_names([ - # this representa a data column named ('a', 'group_key') at position 0 + # this represents a data column named ('a', 'group_key') at position 0 '"\'{""0"": ""a"", ""1"": ""group_key"", ""data_pos"": 0, ""names"": [""c1"", ""c2""]}\'"', # this represents a data column named ('b', 'int_col') at position 1 '"\'{""0"": ""b"", ""1"": ""int_col"", ""data_pos"": 1, ""names"": [""c1"", ""c2""]}\'"', @@ -1110,7 +1128,9 @@ def groupby_apply_pivot_result_to_final_ordered_dataframe( # in GROUP_KEY_APPEARANCE_ORDER) and assign the # label i to all rows that came from func(group_i). [ - original_row_position_snowflake_quoted_identifier + col(original_row_position_snowflake_quoted_identifier).as_( + new_index_identifier + ) if sort_method is GroupbyApplySortMethod.ORIGINAL_ROW_ORDER else ( dense_rank().over( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 457bd388f2b..79f063b9ece 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -1075,7 +1075,7 @@ def join_on_index_columns( Returns: An InternalFrame for the joined result. - A JoinOrAlignResultColumnMapper that provides quited identifiers mapping from the + A JoinOrAlignResultColumnMapper that provides quoted identifiers mapping from the original left and right dataframe to the joined dataframe, it is guaranteed to include mapping for index + data columns, ordering columns and row position column if exists. @@ -1263,7 +1263,7 @@ def align_on_index( * outer: use union of index from both frames, sort index lexicographically. 
Returns: An InternalFrame for the aligned result. - A JoinOrAlignResultColumnMapper that provides quited identifiers mapping from the + A JoinOrAlignResultColumnMapper that provides quoted identifiers mapping from the original left and right dataframe to the aligned dataframe, it is guaranteed to include mapping for index + data columns, ordering columns and row position column if exists. diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 400e98562f9..f5c6be3b751 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -172,6 +172,7 @@ APPLY_LABEL_COLUMN_QUOTED_IDENTIFIER, APPLY_VALUE_COLUMN_QUOTED_IDENTIFIER, DEFAULT_UDTF_PARTITION_SIZE, + GroupbyApplySortMethod, check_return_variant_and_get_return_type, create_udf_for_series_apply, create_udtf_for_apply_axis_1, @@ -3757,6 +3758,8 @@ def groupby_apply( agg_args: Any, agg_kwargs: dict[str, Any], series_groupby: bool, + force_single_group: bool = False, + force_list_like_to_series: bool = False, ) -> "SnowflakeQueryCompiler": """ Group according to `by` and `level`, apply a function to each group, and combine the results. @@ -3777,6 +3780,10 @@ def groupby_apply( Keyword arguments to pass to agg_func when applying it to each group. series_groupby: Whether we are performing a SeriesGroupBy.apply() instead of a DataFrameGroupBy.apply() + force_single_group: + Force single group (empty set of group by labels) useful for DataFrame.apply() with axis=0 + force_list_like_to_series: + Force the function result to series if it is list-like Returns ------- @@ -3804,15 +3811,23 @@ def groupby_apply( dropna = groupby_kwargs.get("dropna", True) group_keys = groupby_kwargs.get("group_keys", False) - by_pandas_labels = extract_groupby_column_pandas_labels(self, by, level) + by_pandas_labels = ( + [] + if force_single_group + else extract_groupby_column_pandas_labels(self, by, level) + ) - by_snowflake_quoted_identifiers_list = [ - quoted_identifier - for entry in self._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - by_pandas_labels - ) - for quoted_identifier in entry - ] + by_snowflake_quoted_identifiers_list = ( + [] + if force_single_group + else [ + quoted_identifier + for entry in self._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + by_pandas_labels + ) + for quoted_identifier in entry + ] + ) snowflake_type_map = self._modin_frame.quoted_identifier_to_snowflake_type() @@ -3846,11 +3861,14 @@ def groupby_apply( ], session=self._modin_frame.ordered_dataframe.session, series_groupby=series_groupby, - by_types=[ + by_types=[] + if force_single_group + else [ snowflake_type_map[quoted_identifier] for quoted_identifier in by_snowflake_quoted_identifiers_list ], existing_identifiers=self._modin_frame.ordered_dataframe._dataframe_ref.snowflake_quoted_identifiers, + force_list_like_to_series=force_list_like_to_series, ) new_internal_df = self._modin_frame.ensure_row_position_column() @@ -3922,9 +3940,9 @@ def groupby_apply( *new_internal_df.index_column_snowflake_quoted_identifiers, *input_data_column_identifiers, ).over( - partition_by=[ - *by_snowflake_quoted_identifiers_list, - ], + partition_by=None + if force_single_group + else [*by_snowflake_quoted_identifiers_list], order_by=row_position_snowflake_quoted_identifier, ), ) @@ -4066,7 +4084,9 @@ def groupby_apply( 
ordered_dataframe=ordered_dataframe, agg_func=agg_func, by_snowflake_quoted_identifiers_list=by_snowflake_quoted_identifiers_list, - sort_method=groupby_apply_sort_method( + sort_method=GroupbyApplySortMethod.ORIGINAL_ROW_ORDER + if force_single_group + else groupby_apply_sort_method( sort, group_keys, original_row_position_snowflake_quoted_identifier, @@ -7888,11 +7908,6 @@ def apply( """ self._raise_not_implemented_error_for_timedelta() - # axis=0 is not supported, raise error. - if axis == 0: - ErrorMessage.not_implemented( - "Snowpark pandas apply API doesn't yet support axis == 0" - ) # Only callables are supported for axis=1 mode for now. if not callable(func) and not isinstance(func, UserDefinedFunction): ErrorMessage.not_implemented( @@ -7909,56 +7924,260 @@ def apply( "Snowpark pandas apply API doesn't yet support DataFrame or Series in 'args' or 'kwargs' of 'func'" ) - # get input types of all data columns from the dataframe directly - input_types = self._modin_frame.get_snowflake_type( - self._modin_frame.data_column_snowflake_quoted_identifiers - ) + if axis == 0: + frame = self._modin_frame - from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native + # To apply function to Dataframe with axis=0, we repurpose the groupby apply function by taking each + # column, as a series, and treat as a single group to apply function. Then collect the column results to + # join together for the final result. + col_results = [] - # current columns - column_index = try_convert_index_to_native(self._modin_frame.data_columns_index) + # If raw, then pass numpy ndarray rather than pandas Series as input to the apply function. + if raw: - # Extract return type from annotations (or lookup for known pandas functions) for func object, - # if not return type could be extracted the variable will hold None. - return_type = deduce_return_type_from_function(func) + def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no cover: adding type hint causes an error when creating udtf. also, skip coverage for this function because coverage tools can't tell that we're executing this function because we execute it in a UDTF. + raw_input_obj = args[0].to_numpy() + args = (raw_input_obj,) + args[1:] + return func(*args, **kwargs) - # Check whether return_type has been extracted. If return type is not - # a Series, tuple or list object, compute df.apply using a vUDF. In this case no column expansion needs to - # be performed which means that the result of df.apply(axis=1) is always a Series object. - if return_type and not ( - isinstance(return_type, PandasSeriesType) - or isinstance(return_type, ArrayType) - ): - return self._apply_udf_row_wise_and_reduce_to_series_along_axis_1( - func, - column_index, - input_types, - return_type, - udf_args=args, - udf_kwargs=kwargs, - session=self._modin_frame.ordered_dataframe.session, - ) + agg_func = wrapped_func + else: + agg_func = func + + # Accumulate indices of the column results. + col_result_indexes = [] + # Accumulate "is scalar" flags for the column results. + col_result_scalars = [] + + # Loop through each data column of the original df frame + for (column_index, data_column_pair) in enumerate( + zip( + frame.data_column_pandas_labels, + frame.data_column_snowflake_quoted_identifiers, + ) + ): + ( + data_column_pandas_label, + data_column_snowflake_quoted_identifier, + ) = data_column_pair + + # Create a frame for the current data column which we will be passed to the apply function below. 
+ # Note that we maintain the original index because the apply function may access via the index. + data_col_qc = self.take_2d_positional( + index=slice(None, None), columns=[column_index] + ) + + data_col_frame = data_col_qc._modin_frame + + data_col_qc = data_col_qc.groupby_apply( + by=[], + agg_func=agg_func, + axis=0, + groupby_kwargs={"as_index": False, "dropna": False}, + agg_args=args, + agg_kwargs=kwargs, + series_groupby=True, + force_single_group=True, + force_list_like_to_series=True, + ) + + data_col_result_frame = data_col_qc._modin_frame + + # Set the index names and corresponding data column pandas label on the result. + data_col_result_frame = InternalFrame.create( + ordered_dataframe=data_col_result_frame.ordered_dataframe, + data_column_snowflake_quoted_identifiers=data_col_result_frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_labels=[data_column_pandas_label], + data_column_pandas_index_names=data_col_frame.data_column_pandas_index_names, + data_column_types=None, + index_column_snowflake_quoted_identifiers=data_col_result_frame.index_column_snowflake_quoted_identifiers, + index_column_pandas_labels=data_col_result_frame.index_column_pandas_labels, + index_column_types=data_col_result_frame.cached_index_column_snowpark_pandas_types, + ) + + data_col_result_index = ( + data_col_result_frame.index_columns_pandas_index() + ) + col_result_indexes.append(data_col_result_index) + # TODO: For functions like np.sum, when supported, we can know upfront the result is a scalar + # so don't need to look at the index. + col_result_scalars.append( + len(data_col_result_index) == 1 and data_col_result_index[0] == -1 + ) + col_results.append(SnowflakeQueryCompiler(data_col_result_frame)) + + result_is_series = False + + if len(col_results) == 1: + result_is_series = col_result_scalars[0] + qc_result = col_results[0] + + # Squeeze to series if it is single column + qc_result = qc_result.columnarize() + if col_result_scalars[0]: + qc_result = qc_result.reset_index(drop=True) + else: + single_row_output = all(len(index) == 1 for index in col_result_indexes) + if single_row_output: + all_scalar_output = all( + is_scalar for is_scalar in col_result_scalars + ) + if all_scalar_output: + # If the apply function maps all columns to a scalar value, then we need to join them together + # to return as a Series result. + + # Ensure all column results have the same column name so concat will be aligned. + for i, qc in enumerate(col_results): + col_results[i] = qc.set_columns([0]) + + qc_result = col_results[0].concat( + axis=0, + other=col_results[1:], + keys=frame.data_column_pandas_labels, + ) + qc_frame = qc_result._modin_frame + + # Drop the extraneous index column from the original result series. 
+ qc_result = SnowflakeQueryCompiler( + InternalFrame.create( + ordered_dataframe=qc_frame.ordered_dataframe, + data_column_snowflake_quoted_identifiers=qc_frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_labels=qc_frame.data_column_pandas_labels, + data_column_pandas_index_names=qc_frame.data_column_pandas_index_names, + data_column_types=qc_frame.cached_data_column_snowpark_pandas_types, + index_column_snowflake_quoted_identifiers=qc_frame.index_column_snowflake_quoted_identifiers[ + :-1 + ], + index_column_pandas_labels=qc_frame.index_column_pandas_labels[ + :-1 + ], + index_column_types=qc_frame.cached_index_column_snowpark_pandas_types[ + :-1 + ], + ) + ) + + result_is_series = True + else: + no_scalar_output = all( + not is_scalar for is_scalar in col_result_scalars + ) + if no_scalar_output: + # Output is Dataframe + all_same_index = col_result_indexes.count( + col_result_indexes[0] + ) == len(col_result_indexes) + qc_result = col_results[0].concat( + axis=1, other=col_results[1:], sort=not all_same_index + ) + else: + # If there's a mix of scalar and pd.Series output from the apply func, pandas stores the + # pd.Series output as the value, which we do not currently support. + ErrorMessage.not_implemented( + "Nested pd.Series in result is not supported in DataFrame.apply(axis=0)" + ) + else: + if any(is_scalar for is_scalar in col_result_scalars): + # If there's a mix of scalar and pd.Series output from the apply func, pandas stores the + # pd.Series output as the value, which we do not currently support. + ErrorMessage.not_implemented( + "Nested pd.Series in result is not supported in DataFrame.apply(axis=0)" + ) + + duplicate_index_values = not all( + len(i) == len(set(i)) for i in col_result_indexes + ) + + # If there are duplicate index values then align on the index for matching results with Pandas. + if duplicate_index_values: + curr_frame = col_results[0]._modin_frame + for next_qc in col_results[1:]: + curr_frame = join_utils.align( + curr_frame, next_qc._modin_frame, [], [], how="left" + ).result_frame + qc_result = SnowflakeQueryCompiler(curr_frame) + else: + # If there are multiple output series with different indices, then line them up as a series output. + all_same_index = all( + all(i == col_result_indexes[0]) for i in col_result_indexes + ) + # If the col results all have same index then we keep the existing index ordering. + qc_result = col_results[0].concat( + axis=1, other=col_results[1:], sort=not all_same_index + ) + + # If result should be Series then change the data column label appropriately. + if result_is_series: + qc_result_frame = qc_result._modin_frame + qc_result = SnowflakeQueryCompiler( + InternalFrame.create( + ordered_dataframe=qc_result_frame.ordered_dataframe, + data_column_snowflake_quoted_identifiers=qc_result_frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_labels=[MODIN_UNNAMED_SERIES_LABEL], + data_column_pandas_index_names=qc_result_frame.data_column_pandas_index_names, + data_column_types=qc_result_frame.cached_data_column_snowpark_pandas_types, + index_column_snowflake_quoted_identifiers=qc_result_frame.index_column_snowflake_quoted_identifiers, + index_column_pandas_labels=qc_result_frame.index_column_pandas_labels, + index_column_types=qc_result_frame.cached_index_column_snowpark_pandas_types, + ) + ) + + return qc_result else: - # Issue actionable warning for users to consider annotating UDF with type annotations - # for better performance. 
- function_name = ( - func.__name__ if isinstance(func, Callable) else str(func) # type: ignore[arg-type] + # get input types of all data columns from the dataframe directly + input_types = self._modin_frame.get_snowflake_type( + self._modin_frame.data_column_snowflake_quoted_identifiers ) - WarningMessage.single_warning( - f"Function {function_name} passed to apply does not have type annotations," - f" or Snowpark pandas could not extract type annotations. Executing apply" - f" in slow code path which may result in decreased performance. " - f"To disable this warning and improve performance, consider annotating" - f" {function_name} with type annotations." + + from snowflake.snowpark.modin.pandas.utils import ( + try_convert_index_to_native, ) - # Result may need to get expanded into multiple columns, or return type of func is not known. - # Process using UDTF together with dynamic pivot for either case. - return self._apply_with_udtf_and_dynamic_pivot_along_axis_1( - func, raw, result_type, args, column_index, input_types, **kwargs + # current columns + column_index = try_convert_index_to_native( + self._modin_frame.data_columns_index ) + # Extract return type from annotations (or lookup for known pandas functions) for func object, + # if not return type could be extracted the variable will hold None. + return_type = deduce_return_type_from_function(func) + + # Check whether return_type has been extracted. If return type is not + # a Series, tuple or list object, compute df.apply using a vUDF. In this case no column expansion needs to + # be performed which means that the result of df.apply(axis=1) is always a Series object. + if return_type and not ( + isinstance(return_type, PandasSeriesType) + or isinstance(return_type, ArrayType) + ): + return self._apply_udf_row_wise_and_reduce_to_series_along_axis_1( + func, + column_index, + input_types, + return_type, + udf_args=args, + udf_kwargs=kwargs, + session=self._modin_frame.ordered_dataframe.session, + ) + else: + # Issue actionable warning for users to consider annotating UDF with type annotations + # for better performance. + function_name = ( + func.__name__ if isinstance(func, Callable) else str(func) # type: ignore[arg-type] + ) + WarningMessage.single_warning( + f"Function {function_name} passed to apply does not have type annotations," + f" or Snowpark pandas could not extract type annotations. Executing apply" + f" in slow code path which may result in decreased performance. " + f"To disable this warning and improve performance, consider annotating" + f" {function_name} with type annotations." + ) + + # Result may need to get expanded into multiple columns, or return type of func is not known. + # Process using UDTF together with dynamic pivot for either case. + return self._apply_with_udtf_and_dynamic_pivot_along_axis_1( + func, raw, result_type, args, column_index, input_types, **kwargs + ) + def applymap( self, func: AggFuncType, diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py index 6d79d07ab84..6223e9dd273 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py @@ -730,7 +730,7 @@ def apply(): Parameters ---------- func : function - A Python function object to apply to each column or row, or a Python function decorated with @udf. + A Python function object to apply to each column or row. 
axis : {0 or 'index', 1 or 'columns'}, default 0 Axis along which the function is applied: @@ -738,8 +738,6 @@ def apply(): * 0 or 'index': apply function to each column. * 1 or 'columns': apply function to each row. - Snowpark pandas does not yet support ``axis=0``. - raw : bool, default False Determines if row or column is passed as a Series or ndarray object: @@ -810,8 +808,6 @@ def apply(): 7. When ``func`` uses any first-party modules or third-party packages inside the function, you need to add these dependencies via ``session.add_import()`` and ``session.add_packages()``. - Alternatively. specify third-party packages with the @udf decorator. When using the @udf decorator, - annotations using PandasSeriesType or PandasDataFrameType are not supported. 8. The Snowpark pandas module cannot currently be referenced inside the definition of ``func``. If you need to call a general pandas API like ``pd.Timestamp`` inside ``func``, @@ -852,22 +848,6 @@ def apply(): 1 14.50 2 24.25 dtype: float64 - - or annotate the function - with the @udf decorator from Snowpark https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/api/snowflake.snowpark.functions.udf. - - >>> from snowflake.snowpark.functions import udf - >>> from snowflake.snowpark.types import DoubleType - >>> @udf(packages=['statsmodels>0.12'], return_type=DoubleType()) - ... def autocorr(column): - ... import pandas as pd - ... import statsmodels.tsa.stattools - ... return pd.Series(statsmodels.tsa.stattools.pacf_ols(column.values)).mean() - ... - >>> df.apply(autocorr, axis=0) # doctest: +SKIP - A 0.857143 - B 0.428571 - dtype: float64 """ def assign(): @@ -1061,8 +1041,6 @@ def transform(): axis : {0 or 'index', 1 or 'columns'}, default 0 If 0 or 'index': apply function to each column. If 1 or 'columns': apply function to each row. - Snowpark pandas currently only supports axis=1, and does not yet support axis=0. - *args Positional arguments to pass to `func`. 
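A hedged sketch of `DataFrame.transform` with a callable, which the documentation changes above now list as supported; the input data mirrors the transform tests below, while the variable names and comparison are illustrative assumptions and an active Snowpark session is assumed:

```python
import modin.pandas as pd
import pandas as native_pd
import snowflake.snowpark.modin.plugin  # noqa: F401

data = [[0, 1, 2], [1, 2, 3]]
snow_df = pd.DataFrame(data)

# Callable funcs are applied column-wise and preserve the frame's shape.
result = snow_df.transform(lambda x: x + 1)

# The lazy Snowpark pandas result is expected to match native pandas for the
# same callable (the tests below compare via eval_snowpark_pandas_result).
expected = native_pd.DataFrame(data).transform(lambda x: x + 1)
```

Per the updated tests, string function names such as `"sqrt"` or `"exp"` still raise `NotImplementedError`; only callables take the new code path.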
diff --git a/tests/integ/modin/frame/test_apply.py b/tests/integ/modin/frame/test_apply.py index 1014cae44c9..ded0651046c 100644 --- a/tests/integ/modin/frame/test_apply.py +++ b/tests/integ/modin/frame/test_apply.py @@ -337,16 +337,6 @@ def f(x, y, z=1) -> int: class TestNotImplemented: - @pytest.mark.parametrize( - "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP - ) - @sql_count_checker(query_count=0) - def test_axis_0(self, data, func, return_type): - snow_df = pd.DataFrame(data) - msg = "Snowpark pandas apply API doesn't yet support axis == 0" - with pytest.raises(NotImplementedError, match=msg): - snow_df.apply(func) - @pytest.mark.parametrize("result_type", ["reduce", "expand", "broadcast"]) @sql_count_checker(query_count=0) def test_result_type(self, result_type): @@ -554,33 +544,70 @@ def g(v): ] -TRANSFORM_DATA_FUNC_MAP = [ - [[[0, 1, 2], [1, 2, 3]], lambda x: x + 1], - [[[0, 1, 2], [1, 2, 3]], np.exp], - [[[0, 1, 2], [1, 2, 3]], "exp"], - [[["Leonhard", "Jianzhun"]], lambda x: x + " is awesome!!"], - [[[1.3, 2.5]], np.sqrt], - [[[1.3, 2.5]], "sqrt"], - [[[1.3, 2.5]], np.log], - [[[1.3, 2.5]], "log"], - [[[1.3, 2.5]], np.square], - [[[1.3, 2.5]], "square"], +@pytest.mark.xfail( + strict=True, + raises=SnowparkSQLException, + reason="SNOW-1650918: Apply on dataframe data columns containing NULL fails with invalid arguments to udtf function", +) +@pytest.mark.parametrize( + "data, apply_func", [ - [[None, "abcd"]], - lambda x: x + " are first 4 letters of alphabet" if x is not None else None, + [ + [[None, "abcd"]], + lambda x: x + " are first 4 letters of alphabet" if x is not None else None, + ], + [ + [[123, None]], + lambda x: x + 100 if x is not None else None, + ], ], - [[[1.5, float("nan")]], lambda x: np.sqrt(x)], +) +def test_apply_bug_1650918(data, apply_func): + native_df = native_pd.DataFrame(data) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(apply_func, axis=1), + ) + + +TRANSFORM_TEST_MAP = [ + [[[0, 1, 2], [1, 2, 3]], lambda x: x + 1, 16], + [[[0, 1, 2], [1, 2, 3]], np.exp, 16], + [[[0, 1, 2], [1, 2, 3]], "exp", None], + [[["Leonhard", "Jianzhun"]], lambda x: x + " is awesome!!", 11], + [[[1.3, 2.5]], np.sqrt, 11], + [[[1.3, 2.5]], "sqrt", None], + [[[1.3, 2.5]], np.log, 11], + [[[1.3, 2.5]], "log", None], + [[[1.3, 2.5]], np.square, 11], + [[[1.3, 2.5]], "square", None], + [[[1.5, float("nan")]], lambda x: np.sqrt(x), 11], ] @pytest.mark.modin_sp_precommit -@pytest.mark.parametrize("data, apply_func", TRANSFORM_DATA_FUNC_MAP) -@sql_count_checker(query_count=0) -def test_basic_dataframe_transform(data, apply_func): - msg = "Snowpark pandas apply API doesn't yet support axis == 0" - with pytest.raises(NotImplementedError, match=msg): +@pytest.mark.parametrize("data, apply_func, expected_query_count", TRANSFORM_TEST_MAP) +def test_basic_dataframe_transform(data, apply_func, expected_query_count): + if expected_query_count is None: + msg = "Snowpark pandas apply API only supports callables func" + with SqlCounter(query_count=0): + with pytest.raises(NotImplementedError, match=msg): + snow_df = pd.DataFrame(data) + snow_df.transform(apply_func) + else: + msg = "SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function" + native_df = native_pd.DataFrame(data) snow_df = pd.DataFrame(data) - snow_df.transform(apply_func) + with SqlCounter( + query_count=expected_query_count, + high_count_expected=True, + high_count_reason=msg, + ): + 
eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.transform(apply_func) + ) AGGREGATION_FUNCTIONS = [ @@ -610,7 +637,7 @@ def test_dataframe_transform_invalid_function_name_negative(session): snow_df = pd.DataFrame([[0, 1, 2], [1, 2, 3]]) with pytest.raises( NotImplementedError, - match="Snowpark pandas apply API doesn't yet support axis == 0", + match="Snowpark pandas apply API only supports callables func", ): snow_df.transform("mxyzptlk") diff --git a/tests/integ/modin/frame/test_apply_axis_0.py b/tests/integ/modin/frame/test_apply_axis_0.py new file mode 100644 index 00000000000..47fd14d7b98 --- /dev/null +++ b/tests/integ/modin/frame/test_apply_axis_0.py @@ -0,0 +1,653 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import datetime + +import modin.pandas as pd +import numpy as np +import pandas as native_pd +import pytest + +import snowflake.snowpark.modin.plugin # noqa: F401 +from snowflake.snowpark.exceptions import SnowparkSQLException +from tests.integ.modin.series.test_apply import create_func_with_return_type_hint +from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker +from tests.integ.modin.utils import ( + assert_snowpark_pandas_equal_to_pandas, + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + create_test_dfs, + eval_snowpark_pandas_result, +) + +# test data which has a python type as return type that is not a pandas Series/pandas DataFrame/tuple/list +BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP = [ + [[[1.0, 2.2], [3, np.nan]], np.min, "float"], + [[[1.1, 2.2], [3, np.nan]], lambda x: x.sum(), "float"], + [[[1.1, 2.2], [3, np.nan]], lambda x: x.size, "int"], + [[[1.1, 2.2], [3, np.nan]], lambda x: "0" if x.sum() > 1 else 0, "object"], + [[["snow", "flake"], ["data", "cloud"]], lambda x: x[0] + x[1], "str"], + [[[True, False], [False, False]], lambda x: True, "bool"], + [[[True, False], [False, False]], lambda x: x[0] ^ x[1], "bool"], + ( + [ + [bytes("snow", "utf-8"), bytes("flake", "utf-8")], + [bytes("data", "utf-8"), bytes("cloud", "utf-8")], + ], + lambda x: (x[0] + x[1]).decode(), + "str", + ), + ( + [[["a", "b"], ["c", "d"]], [["a", "b"], ["c", "d"]]], + lambda x: x[0][1] + x[1][0], + "str", + ), + ( + [[{"a": "b"}, {"c": "d"}], [{"c": "b"}, {"a": "d"}]], + lambda x: str(x[0]) + str(x[1]), + "str", + ), +] + + +@pytest.mark.parametrize( + "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP +) +@pytest.mark.modin_sp_precommit +def test_axis_0_basic_types_without_type_hints(data, func, return_type): + # this test processes functions without type hints and invokes the UDTF solution. + native_df = native_pd.DataFrame(data, columns=["A", "b"]) + snow_df = pd.DataFrame(data, columns=["A", "b"]) + with SqlCounter( + query_count=11, + join_count=2, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result(snow_df, native_df, lambda x: x.apply(func, axis=0)) + + +@pytest.mark.parametrize( + "data, func, return_type", BASIC_DATA_FUNC_PYTHON_RETURN_TYPE_MAP +) +@pytest.mark.modin_sp_precommit +def test_axis_0_basic_types_with_type_hints(data, func, return_type): + # create explicitly for supported python types UDF with type hints and process via vUDF. 
+ native_df = native_pd.DataFrame(data, columns=["A", "b"]) + snow_df = pd.DataFrame(data, columns=["A", "b"]) + func_with_type_hint = create_func_with_return_type_hint(func, return_type) + # Invoking a single UDF typically requires 3 queries (package management, code upload, UDF registration) upfront. + with SqlCounter( + query_count=11, + join_count=2, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(func_with_type_hint, axis=0) + ) + + +@pytest.mark.parametrize( + "df,row_label", + [ + ( + native_pd.DataFrame( + [[1, 2], [None, 3]], columns=["A", "b"], index=["A", "B"] + ), + "B", + ), + ( + native_pd.DataFrame( + [[1, 2], [None, 3]], + columns=["A", "b"], + index=pd.MultiIndex.from_tuples([(1, 2), (1, 1)]), + ), + (1, 2), + ), + ], +) +def test_axis_0_index_passed_as_name(df, row_label): + # when using apply(axis=1) the original index of the dataframe is passed as name. + # test here for this for regular index and multi-index scenario. + + def foo(row) -> str: + if row.name == row_label: + return "MATCHING LABEL" + else: + return "NO MATCH" + + snow_df = pd.DataFrame(df) + with SqlCounter( + query_count=11, + join_count=2, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=0)) + + +@sql_count_checker( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_return_series(): + snow_df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"]) + native_df = native_pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"]) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(lambda x: native_pd.Series([1, 2], index=["C", "d"]), axis=0), + ) + + +@sql_count_checker( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_return_series_with_different_label_results(): + df = native_pd.DataFrame([[1, 2], [3, 4]], columns=["A", "b"]) + snow_df = pd.DataFrame(df) + + eval_snowpark_pandas_result( + snow_df, + df, + lambda df: df.apply( + lambda x: native_pd.Series([1, 2], index=["a", "b"]) + if x.sum() > 3 + else native_pd.Series([0, 1, 2], index=["c", "a", "b"]), + axis=0, + ), + ) + + +@sql_count_checker(query_count=6, join_count=1, udtf_count=1) +def test_axis_0_return_single_scalar_series(): + native_df = native_pd.DataFrame([1]) + snow_df = pd.DataFrame(native_df) + + def apply_func(x): + return native_pd.Series([1], index=["xyz"]) + + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(apply_func, axis=0) + ) + + +@sql_count_checker(query_count=3) +def test_axis_0_return_dataframe_not_supported(): + snow_df = pd.DataFrame([1]) + + # Note that pands returns failure "ValueError: If using all scalar values, you must pass an index" which + # doesn't explain this isn't supported. We go with the default returned by pandas in this case. + with pytest.raises( + SnowparkSQLException, match="The truth value of a DataFrame is ambiguous." 
+ ): + # return value + snow_df.apply(lambda x: native_pd.DataFrame([1, 2]), axis=0).to_pandas() + + +class TestNotImplemented: + @pytest.mark.parametrize("result_type", ["reduce", "expand", "broadcast"]) + @sql_count_checker(query_count=0) + def test_result_type(self, result_type): + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + msg = "Snowpark pandas apply API doesn't yet support 'result_type' parameter" + with pytest.raises(NotImplementedError, match=msg): + snow_df.apply(lambda x: [1, 2], axis=0, result_type=result_type) + + @sql_count_checker(query_count=0) + def test_axis_1_apply_args_kwargs_with_snowpandas_object(self): + def f(x, y=None) -> native_pd.Series: + return x + (y if y is not None else 0) + + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + msg = "Snowpark pandas apply API doesn't yet support DataFrame or Series in 'args' or 'kwargs' of 'func'" + with pytest.raises(NotImplementedError, match=msg): + snow_df.apply(f, axis=0, args=(pd.Series([1, 2]),)) + with pytest.raises(NotImplementedError, match=msg): + snow_df.apply(f, axis=0, y=pd.Series([1, 2])) + + +TEST_INDEX_1 = native_pd.MultiIndex.from_tuples( + list(zip(*[["a", "b"], ["x", "y"]])), + names=["first", "last"], +) + + +TEST_INDEX_WITH_NULL_1 = native_pd.MultiIndex.from_tuples( + list(zip(*[[None, "b"], ["x", None]])), + names=["first", "last"], +) + + +TEST_INDEX_2 = native_pd.MultiIndex.from_tuples( + list(zip(*[["AA", "BB"], ["XX", "YY"]])), + names=["FOO", "BAR"], +) + +TEST_INDEX_WITH_NULL_2 = native_pd.MultiIndex.from_tuples( + list(zip(*[[None, "BB"], ["XX", None]])), + names=["FOO", "BAR"], +) + + +TEST_COLUMNS_1 = native_pd.MultiIndex.from_tuples( + list( + zip( + *[ + ["car", "motorcycle", "bike", "bus"], + ["blue", "green", "red", "yellow"], + ] + ) + ), + names=["vehicle", "color"], +) + + +@pytest.mark.parametrize( + "apply_func, expected_join_count, expected_union_count", + [ + [lambda x: [1, 2], 3, 0], + [lambda x: x + 1 if x is not None else None, 3, 0], + [lambda x: x.min(), 2, 1], + ], +) +def test_axis_0_series_basic(apply_func, expected_join_count, expected_union_count): + native_df = native_pd.DataFrame( + [[1.1, 2.2], [3.0, None]], index=pd.Index([2, 3]), columns=["A", "b"] + ) + snow_df = pd.DataFrame(native_df) + with SqlCounter( + query_count=11, + join_count=expected_join_count, + udtf_count=2, + union_count=expected_union_count, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.apply(apply_func, axis=0), + ) + + +@sql_count_checker(query_count=5, join_count=1, udtf_count=1) +def test_groupby_apply_constant_output(): + native_df = native_pd.DataFrame([1, 2]) + native_df["fg"] = 0 + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by=["fg"], axis=0).apply(lambda x: [1, 2]), + ) + + +@sql_count_checker( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_return_list(): + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + native_df = native_pd.DataFrame([[1, 2], [3, 4]]) + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(lambda x: [1, 2], axis=0) + ) + + +@pytest.mark.parametrize( + "apply_func", + [ + lambda x: -x, + lambda x: native_pd.Series([1, 2], index=TEST_INDEX_1), + lambda x: 
native_pd.Series([3, 4], index=TEST_INDEX_2), + lambda x: native_pd.Series([1, 2], index=TEST_INDEX_WITH_NULL_1), + lambda x: native_pd.Series([1, 2], index=TEST_INDEX_WITH_NULL_1), + ], +) +@sql_count_checker( + query_count=21, + join_count=7, + udtf_count=4, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_multi_index_column_labels(apply_func): + data = [[i + j for j in range(0, 4)] for i in range(0, 4)] + + native_df = native_pd.DataFrame(data, columns=TEST_COLUMNS_1) + snow_df = pd.DataFrame(data, columns=TEST_COLUMNS_1) + + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(apply_func, axis=0) + ) + + +@sql_count_checker( + query_count=21, + join_count=7, + udtf_count=4, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_multi_index_column_labels_with_different_results(): + data = [[i + j for j in range(0, 4)] for i in range(0, 4)] + + df = native_pd.DataFrame(data, columns=TEST_COLUMNS_1) + snow_df = pd.DataFrame(df) + + apply_func = ( + lambda x: native_pd.Series([1, 2], index=TEST_INDEX_1) + if min(x) == 0 + else native_pd.Series([3, 4], index=TEST_INDEX_2) + ) + + eval_snowpark_pandas_result(snow_df, df, lambda df: df.apply(apply_func, axis=0)) + + +@pytest.mark.parametrize( + "data, func, expected_result", + [ + [ + [ + [datetime.date(2023, 1, 1), None], + [datetime.date(2022, 12, 31), datetime.date(2021, 1, 9)], + ], + lambda x: x.dt.day, + native_pd.DataFrame([[1, np.nan], [31, 9.0]]), + ], + [ + [ + [datetime.time(1, 2, 3), None], + [datetime.time(1, 2, 3, 1), datetime.time(1)], + ], + lambda x: x.dt.seconds, + native_pd.DataFrame([[3723, np.nan], [3723, 3600]]), + ], + [ + [ + [datetime.datetime(2023, 1, 1, 1, 2, 3), None], + [ + datetime.datetime(2022, 12, 31, 1, 2, 3, 1), + datetime.datetime( + 2023, 1, 1, 1, 2, 3, tzinfo=datetime.timezone.utc + ), + ], + ], + lambda x: x.astype(str), + native_pd.DataFrame( + [ + ["2023-01-01 01:02:03.000000", "NaT"], + ["2022-12-31 01:02:03.000001", "2023-01-01 01:02:03+00:00"], + ] + ), + ], + ], +) +@sql_count_checker( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_date_time_timestamp_type(data, func, expected_result): + snow_df = pd.DataFrame(data) + result = snow_df.apply(func, axis=0) + + assert_snowpark_pandas_equal_to_pandas(result, expected_result) + + +@pytest.mark.parametrize( + "native_df, func", + [ + ( + native_pd.DataFrame([[1, 2], [3, 4]], index=["a", "b"]), + lambda x: x["a"] + x["b"], + ), + ( + native_pd.DataFrame( + [[1, 5], [2, 6], [3, 7], [4, 8]], + index=native_pd.MultiIndex.from_tuples( + [("baz", "A"), ("baz", "B"), ("zoo", "A"), ("zoo", "B")] + ), + ), + lambda x: x["baz", "B"] * x["zoo", "A"], + ), + ], +) +@sql_count_checker( + query_count=11, + join_count=2, + udtf_count=2, + union_count=1, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_index_labels(native_df, func): + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result(snow_df, native_df, lambda x: x.apply(func, axis=0)) + + +@sql_count_checker( + query_count=11, + join_count=2, + udtf_count=2, + union_count=1, + 
high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", +) +def test_axis_0_raw(): + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + native_df = native_pd.DataFrame([[1, 2], [3, 4]]) + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(lambda x: str(type(x)), axis=0, raw=True) + ) + + +def test_axis_0_apply_args_kwargs(): + def f(x, y, z=1) -> int: + return x.sum() + y + z + + native_df = native_pd.DataFrame([[1, 2], [3, 4]]) + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + + with SqlCounter(query_count=3): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(f, axis=0), + expect_exception=True, + expect_exception_type=SnowparkSQLException, + expect_exception_match="missing 1 required positional argument", + assert_exception_equal=False, + ) + + with SqlCounter( + query_count=11, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(f, axis=0, args=(1,)) + ) + + with SqlCounter( + query_count=11, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result( + snow_df, native_df, lambda x: x.apply(f, axis=0, args=(1,), z=2) + ) + + with SqlCounter(query_count=3): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(f, axis=0, args=(1,), z=2, v=3), + expect_exception=True, + expect_exception_type=SnowparkSQLException, + expect_exception_match="got an unexpected keyword argument", + assert_exception_equal=False, + ) + + +@pytest.mark.parametrize("data", [{"a": [1], "b": [2]}, {"a": [2], "b": [3]}]) +def test_apply_axis_0_with_if_where_duplicates_not_executed(data): + df = native_pd.DataFrame(data) + snow_df = pd.DataFrame(df) + + def foo(x): + return native_pd.Series( + [1, 2, 3], index=["C", "A", "E"] if x.sum() > 3 else ["A", "E", "E"] + ) + + with SqlCounter( + query_count=11, + join_count=3, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + eval_snowpark_pandas_result(snow_df, df, lambda x: x.apply(foo, axis=0)) + + +@pytest.mark.parametrize( + "return_value", + [ + native_pd.Series(["a", np.int64(3)]), + ["a", np.int64(3)], + np.int64(3), + ], +) +@sql_count_checker(query_count=6, join_count=1, udtf_count=1) +def test_numpy_integers_in_return_values_snow_1227264(return_value): + eval_snowpark_pandas_result( + *create_test_dfs(["a"]), lambda df: df.apply(lambda row: return_value, axis=0) + ) + + +@pytest.mark.xfail( + strict=True, + raises=SnowparkSQLException, + reason="SNOW-1650918: Apply on dataframe data columns containing NULL fails with invalid arguments to udtf function", +) +@pytest.mark.parametrize( + "data, apply_func", + [ + [ + [[None, "abcd"]], + lambda x: x + " are first 4 letters of alphabet" if x is not None else None, + ], + [ + [[123, None]], + lambda x: x + 100 if x is not None else None, + ], + ], +) +def test_apply_axis_0_bug_1650918(data, apply_func): + native_df = native_pd.DataFrame(data) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda x: x.apply(apply_func, axis=0), + ) + + +def test_apply_nested_series_negative(): + snow_df = pd.DataFrame([[1, 2], [3, 4]]) + + with SqlCounter( + 
query_count=10, + join_count=2, + udtf_count=2, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + with pytest.raises( + NotImplementedError, + match=r"Nested pd.Series in result is not supported in DataFrame.apply\(axis=0\)", + ): + snow_df.apply( + lambda ser: 99 if ser.sum() == 4 else native_pd.Series([1, 2]), axis=0 + ).to_pandas() + + snow_df2 = pd.DataFrame([[1, 2, 3]]) + + with SqlCounter( + query_count=15, + join_count=3, + udtf_count=3, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + with pytest.raises( + NotImplementedError, + match=r"Nested pd.Series in result is not supported in DataFrame.apply\(axis=0\)", + ): + snow_df2.apply( + lambda ser: 99 + if ser.sum() == 2 + else native_pd.Series([100], index=["a"]), + axis=0, + ).to_pandas() + + +import scipy.stats # noqa: E402 + + +@pytest.mark.parametrize( + "packages,expected_query_count", + [ + (["scipy", "numpy"], 26), + (["scipy>1.1", "numpy<2.0"], 26), + # TODO: SNOW-1478188 Re-enable quarantined tests for 8.23 + # [scipy, np], 9), + ], +) +def test_apply_axis0_with_3rd_party_libraries_and_decorator( + packages, expected_query_count +): + data = [[1, 2, 3, 4, 5], [7, -20, 4.0, 7.0, None]] + + with SqlCounter( + query_count=expected_query_count, + high_count_expected=True, + high_count_reason="SNOW-1650644 & SNOW-1345395: Avoid extra caching and repeatedly creating same temp function", + ): + try: + pd.session.custom_package_usage_config["enabled"] = True + pd.session.add_packages(packages) + + df = pd.DataFrame(data) + + def func(row): + return np.dot(row, scipy.stats.norm.pdf(row)) + + snow_ans = df.apply(func, axis=0) + finally: + pd.session.clear_packages() + pd.session.clear_imports() + + # same in native pandas: + native_df = native_pd.DataFrame(data) + native_ans = native_df.apply(func, axis=0) + + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(snow_ans, native_ans) From 7d54d207c336098abbb616356806e39ca43cd313 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Wed, 11 Sep 2024 11:07:01 -0700 Subject: [PATCH 3/3] SNOW-1418543 make local testing thread safe (#2185) --- src/snowflake/snowpark/mock/_connection.py | 311 +++++----- src/snowflake/snowpark/mock/_functions.py | 23 +- src/snowflake/snowpark/mock/_plan.py | 570 +++++++++--------- .../snowpark/mock/_stage_registry.py | 85 +-- .../snowpark/mock/_stored_procedure.py | 194 +++--- src/snowflake/snowpark/mock/_telemetry.py | 32 +- src/snowflake/snowpark/mock/_udf.py | 177 +++--- src/snowflake/snowpark/session.py | 4 +- tests/mock/test_multithreading.py | 335 ++++++++++ 9 files changed, 1091 insertions(+), 640 deletions(-) create mode 100644 tests/mock/test_multithreading.py diff --git a/src/snowflake/snowpark/mock/_connection.py b/src/snowflake/snowpark/mock/_connection.py index 9e8d4d0d721..b384931cb89 100644 --- a/src/snowflake/snowpark/mock/_connection.py +++ b/src/snowflake/snowpark/mock/_connection.py @@ -6,6 +6,7 @@ import functools import json import logging +import threading import uuid from copy import copy from decimal import Decimal @@ -91,35 +92,39 @@ def __init__(self, conn: "MockServerConnection") -> None: self.table_registry = {} self.view_registry = {} self.conn = conn + self._lock = self.conn.get_lock() def is_existing_table(self, name: Union[str, Iterable[str]]) -> bool: - current_schema = self.conn._get_current_parameter("schema") - 
current_database = self.conn._get_current_parameter("database") - qualified_name = get_fully_qualified_name( - name, current_schema, current_database - ) - return qualified_name in self.table_registry + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + qualified_name = get_fully_qualified_name( + name, current_schema, current_database + ) + return qualified_name in self.table_registry def is_existing_view(self, name: Union[str, Iterable[str]]) -> bool: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - qualified_name = get_fully_qualified_name( - name, current_schema, current_database - ) - return qualified_name in self.view_registry + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + qualified_name = get_fully_qualified_name( + name, current_schema, current_database + ) + return qualified_name in self.view_registry def read_table(self, name: Union[str, Iterable[str]]) -> TableEmulator: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - qualified_name = get_fully_qualified_name( - name, current_schema, current_database - ) - if qualified_name in self.table_registry: - return copy(self.table_registry[qualified_name]) - else: - raise SnowparkLocalTestingException( - f"Object '{name}' does not exist or not authorized." + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + qualified_name = get_fully_qualified_name( + name, current_schema, current_database ) + if qualified_name in self.table_registry: + return copy(self.table_registry[qualified_name]) + else: + raise SnowparkLocalTestingException( + f"Object '{name}' does not exist or not authorized." + ) def write_table( self, @@ -128,127 +133,155 @@ def write_table( mode: SaveMode, column_names: Optional[List[str]] = None, ) -> List[Row]: - for column in table.columns: - if not table[column].sf_type.nullable and table[column].isnull().any(): - raise SnowparkLocalTestingException( - "NULL result in a non-nullable column" - ) - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - name = get_fully_qualified_name(name, current_schema, current_database) - table = copy(table) - if mode == SaveMode.APPEND: - if name in self.table_registry: - target_table = self.table_registry[name] - input_schema = table.columns.to_list() - existing_schema = target_table.columns.to_list() - - if not column_names: # append with column_order being index - if len(input_schema) != len(existing_schema): - raise SnowparkLocalTestingException( - f"Cannot append because incoming data has different schema {input_schema} than existing table {existing_schema}" - ) - # temporarily align the column names of both dataframe to be col indexes 0, 1, ... 
N - 1 - table.columns = range(table.shape[1]) - target_table.columns = range(target_table.shape[1]) - else: # append with column_order being name - if invalid_cols := set(input_schema) - set(existing_schema): - identifiers = "', '".join( - unquote_if_quoted(id) for id in invalid_cols - ) - raise SnowparkLocalTestingException( - f"table contains invalid identifier '{identifiers}'" - ) - invalid_non_nullable_cols = [] - for missing_col in set(existing_schema) - set(input_schema): - if target_table[missing_col].sf_type.nullable: - table[missing_col] = None - table.sf_types[missing_col] = target_table[ - missing_col - ].sf_type - else: - invalid_non_nullable_cols.append(missing_col) - if invalid_non_nullable_cols: - identifiers = "', '".join( - unquote_if_quoted(id) - for id in invalid_non_nullable_cols - ) - raise SnowparkLocalTestingException( - f"NULL result in a non-nullable column '{identifiers}'" - ) - - self.table_registry[name] = pandas.concat( - [target_table, table], ignore_index=True - ) - self.table_registry[name].columns = existing_schema - self.table_registry[name].sf_types = target_table.sf_types - else: - self.table_registry[name] = table - elif mode == SaveMode.IGNORE: - if name not in self.table_registry: - self.table_registry[name] = table - elif mode == SaveMode.OVERWRITE: - self.table_registry[name] = table - elif mode == SaveMode.ERROR_IF_EXISTS: - if name in self.table_registry: - raise SnowparkLocalTestingException(f"Table {name} already exists") - else: - self.table_registry[name] = table - elif mode == SaveMode.TRUNCATE: - if name in self.table_registry: - target_table = self.table_registry[name] - input_schema = set(table.columns.to_list()) - existing_schema = set(target_table.columns.to_list()) - # input is a subset of existing schema and all missing columns are nullable - if input_schema.issubset(existing_schema) and all( - target_table[col].sf_type.nullable - for col in set(existing_schema - input_schema) + with self._lock: + for column in table.columns: + if ( + not table[column].sf_type.nullable + and table[column].isnull().any() ): - for col in set(existing_schema - input_schema): - table[col] = ColumnEmulator( - data=[None] * table.shape[0], - sf_type=target_table[col].sf_type, - dtype=object, - ) + raise SnowparkLocalTestingException( + "NULL result in a non-nullable column" + ) + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + name = get_fully_qualified_name(name, current_schema, current_database) + table = copy(table) + if mode == SaveMode.APPEND: + if name in self.table_registry: + target_table = self.table_registry[name] + input_schema = table.columns.to_list() + existing_schema = target_table.columns.to_list() + + if not column_names: # append with column_order being index + if len(input_schema) != len(existing_schema): + raise SnowparkLocalTestingException( + f"Cannot append because incoming data has different schema {input_schema} than existing table {existing_schema}" + ) + # temporarily align the column names of both dataframe to be col indexes 0, 1, ... 
N - 1 + table.columns = range(table.shape[1]) + target_table.columns = range(target_table.shape[1]) + else: # append with column_order being name + if invalid_cols := set(input_schema) - set(existing_schema): + identifiers = "', '".join( + unquote_if_quoted(id) for id in invalid_cols + ) + raise SnowparkLocalTestingException( + f"table contains invalid identifier '{identifiers}'" + ) + invalid_non_nullable_cols = [] + for missing_col in set(existing_schema) - set(input_schema): + if target_table[missing_col].sf_type.nullable: + table[missing_col] = None + table.sf_types[missing_col] = target_table[ + missing_col + ].sf_type + else: + invalid_non_nullable_cols.append(missing_col) + if invalid_non_nullable_cols: + identifiers = "', '".join( + unquote_if_quoted(id) + for id in invalid_non_nullable_cols + ) + raise SnowparkLocalTestingException( + f"NULL result in a non-nullable column '{identifiers}'" + ) + + self.table_registry[name] = pandas.concat( + [target_table, table], ignore_index=True + ) + self.table_registry[name].columns = existing_schema + self.table_registry[name].sf_types = target_table.sf_types else: + self.table_registry[name] = table + elif mode == SaveMode.IGNORE: + if name not in self.table_registry: + self.table_registry[name] = table + elif mode == SaveMode.OVERWRITE: + self.table_registry[name] = table + elif mode == SaveMode.ERROR_IF_EXISTS: + if name in self.table_registry: raise SnowparkLocalTestingException( - f"Cannot truncate because incoming data has different schema {table.columns.to_list()} than existing table { target_table.columns.to_list()}" + f"Table {name} already exists" ) - table.sf_types_by_col_index = target_table.sf_types_by_col_index - table = table.reindex(columns=target_table.columns) - self.table_registry[name] = table - else: - raise SnowparkLocalTestingException(f"Unrecognized mode: {mode}") - return [ - Row(status=f"Table {name} successfully created.") - ] # TODO: match message + else: + self.table_registry[name] = table + elif mode == SaveMode.TRUNCATE: + if name in self.table_registry: + target_table = self.table_registry[name] + input_schema = set(table.columns.to_list()) + existing_schema = set(target_table.columns.to_list()) + # input is a subset of existing schema and all missing columns are nullable + if input_schema.issubset(existing_schema) and all( + target_table[col].sf_type.nullable + for col in set(existing_schema - input_schema) + ): + for col in set(existing_schema - input_schema): + table[col] = ColumnEmulator( + data=[None] * table.shape[0], + sf_type=target_table[col].sf_type, + dtype=object, + ) + else: + raise SnowparkLocalTestingException( + f"Cannot truncate because incoming data has different schema {table.columns.to_list()} than existing table { target_table.columns.to_list()}" + ) + table.sf_types_by_col_index = target_table.sf_types_by_col_index + table = table.reindex(columns=target_table.columns) + self.table_registry[name] = table + else: + raise SnowparkLocalTestingException(f"Unrecognized mode: {mode}") + return [ + Row(status=f"Table {name} successfully created.") + ] # TODO: match message def drop_table(self, name: Union[str, Iterable[str]]) -> None: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - name = get_fully_qualified_name(name, current_schema, current_database) - if name in self.table_registry: - self.table_registry.pop(name) + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database 
= self.conn._get_current_parameter("database") + name = get_fully_qualified_name(name, current_schema, current_database) + if name in self.table_registry: + self.table_registry.pop(name) def create_or_replace_view( self, execution_plan: MockExecutionPlan, name: Union[str, Iterable[str]] ): - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - name = get_fully_qualified_name(name, current_schema, current_database) - self.view_registry[name] = execution_plan + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + name = get_fully_qualified_name(name, current_schema, current_database) + self.view_registry[name] = execution_plan def get_review(self, name: Union[str, Iterable[str]]) -> MockExecutionPlan: - current_schema = self.conn._get_current_parameter("schema") - current_database = self.conn._get_current_parameter("database") - name = get_fully_qualified_name(name, current_schema, current_database) - if name in self.view_registry: - return self.view_registry[name] - raise SnowparkLocalTestingException(f"View {name} does not exist") + with self._lock: + current_schema = self.conn._get_current_parameter("schema") + current_database = self.conn._get_current_parameter("database") + name = get_fully_qualified_name(name, current_schema, current_database) + if name in self.view_registry: + return self.view_registry[name] + raise SnowparkLocalTestingException(f"View {name} does not exist") + + def read_view_if_exists( + self, name: Union[str, Iterable[str]] + ) -> Optional[MockExecutionPlan]: + """Method to atomically read a view if it exists. Returns None if the view does not exist.""" + with self._lock: + if self.is_existing_view(name): + return self.get_review(name) + return None + + def read_table_if_exists( + self, name: Union[str, Iterable[str]] + ) -> Optional[TableEmulator]: + """Method to atomically read a table if it exists. 
Returns None if the table does not exist.""" + with self._lock: + if self.is_existing_table(name): + return self.read_table(name) + return None def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self._conn = MockedSnowflakeConnection() self._cursor = Mock() + self._lock = threading.RLock() self._lower_case_parameters = {} self.remove_query_listener = Mock() self.add_query_listener = Mock() @@ -301,7 +334,7 @@ def log_not_supported_error( warning_logger: Optional[logging.Logger] = None, ): """ - send telemetry to oob servie, can raise error or logging a warning based upon the input + send telemetry to oob service, can raise error or logging a warning based upon the input Args: external_feature_name: customer facing feature name, this information is used to raise error @@ -323,25 +356,31 @@ def log_not_supported_error( def _get_client_side_session_parameter(self, name: str, default_value: Any) -> Any: # mock implementation - return ( - self._conn._session_parameters.get(name, default_value) - if self._conn._session_parameters - else default_value - ) + with self._lock: + return ( + self._conn._session_parameters.get(name, default_value) + if self._conn._session_parameters + else default_value + ) def get_session_id(self) -> int: return 1 + def get_lock(self): + return self._lock + def close(self) -> None: - if self._conn: - self._conn.close() + with self._lock: + if self._conn: + self._conn.close() def is_closed(self) -> bool: return self._conn.is_closed() def _get_current_parameter(self, param: str, quoted: bool = True) -> Optional[str]: try: - name = getattr(self, f"_active_{param}", None) + with self._lock: + name = getattr(self, f"_active_{param}", None) if name and len(name) >= 2 and name[0] == name[-1] == '"': # it is a quoted identifier, return the original value return name diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py index edf9ffc68b3..3842f6fda34 100644 --- a/src/snowflake/snowpark/mock/_functions.py +++ b/src/snowflake/snowpark/mock/_functions.py @@ -10,6 +10,7 @@ import operator import re import string +import threading from decimal import Decimal from functools import partial, reduce from numbers import Real @@ -130,14 +131,17 @@ def __call__(self, *args, input_data=None, row_number=None, **kwargs): class MockedFunctionRegistry: _instance = None + _lock_init = threading.Lock() def __init__(self) -> None: self._registry = dict() + self._lock = threading.RLock() @classmethod def get_or_create(cls) -> "MockedFunctionRegistry": - if cls._instance is None: - cls._instance = MockedFunctionRegistry() + with cls._lock_init: + if cls._instance is None: + cls._instance = MockedFunctionRegistry() return cls._instance def get_function( @@ -151,10 +155,11 @@ def get_function( distinct = func.is_distinct func_name = func_name.lower() - if func_name not in self._registry: - return None + with self._lock: + if func_name not in self._registry: + return None - function = self._registry[func_name] + function = self._registry[func_name] return function.distinct if distinct else function @@ -169,7 +174,8 @@ def register( snowpark_func if isinstance(snowpark_func, str) else snowpark_func.__name__ ) mocked_function = MockedFunction(name, func_implementation, *args, **kwargs) - self._registry[name] = mocked_function + with self._lock: + self._registry[name] = mocked_function return mocked_function def unregister( @@ -180,8 +186,9 @@ def unregister( snowpark_func if isinstance(snowpark_func, str) else snowpark_func.__name__ ) - if 
name in self._registry: - del self._registry[name] + with self._lock: + if name in self._registry: + del self._registry[name] class LocalTimezone: diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py index 11e54802eea..aa86b2598d6 100644 --- a/src/snowflake/snowpark/mock/_plan.py +++ b/src/snowflake/snowpark/mock/_plan.py @@ -357,18 +357,21 @@ def handle_function_expression( current_row=None, ): func = MockedFunctionRegistry.get_or_create().get_function(exp) + connection_lock = analyzer.session._conn.get_lock() if func is None: - current_schema = analyzer.session.get_current_schema() - current_database = analyzer.session.get_current_database() + with connection_lock: + current_schema = analyzer.session.get_current_schema() + current_database = analyzer.session.get_current_database() udf_name = get_fully_qualified_name(exp.name, current_schema, current_database) # If udf name in the registry then this is a udf, not an actual function - if udf_name in analyzer.session.udf._registry: - exp.udf_name = udf_name - return handle_udf_expression( - exp, input_data, analyzer, expr_to_alias, current_row - ) + with connection_lock: + if udf_name in analyzer.session.udf._registry: + exp.udf_name = udf_name + return handle_udf_expression( + exp, input_data, analyzer, expr_to_alias, current_row + ) if exp.api_call_source == "functions.call_udf": raise SnowparkLocalTestingException( @@ -463,9 +466,12 @@ def handle_udf_expression( ): udf_registry = analyzer.session.udf udf_name = exp.udf_name - udf = udf_registry.get_udf(udf_name) + connection_lock = analyzer.session._conn.get_lock() + with connection_lock: + udf = udf_registry.get_udf(udf_name) + udf_imports = udf_registry.get_udf_imports(udf_name) - with ImportContext(udf_registry.get_udf_imports(udf_name)): + with ImportContext(udf_imports): # Resolve handler callable if type(udf.func) is tuple: module_name, handler_name = udf.func @@ -556,6 +562,7 @@ def execute_mock_plan( analyzer = plan.analyzer entity_registry = analyzer.session._conn.entity_registry + connection_lock = analyzer.session._conn.get_lock() if isinstance(source_plan, SnowflakeValues): table = TableEmulator( @@ -728,18 +735,20 @@ def execute_mock_plan( return res_df if isinstance(source_plan, MockSelectableEntity): entity_name = source_plan.entity.name - if entity_registry.is_existing_table(entity_name): - return entity_registry.read_table(entity_name) - elif entity_registry.is_existing_view(entity_name): - execution_plan = entity_registry.get_review(entity_name) + table = entity_registry.read_table_if_exists(entity_name) + if table is not None: + return table + + execution_plan = entity_registry.read_view_if_exists(entity_name) + if execution_plan is not None: res_df = execute_mock_plan(execution_plan, expr_to_alias) return res_df - else: - db_schme_table = parse_table_name(entity_name) - table = ".".join([part.strip("\"'") for part in db_schme_table[:3]]) - raise SnowparkLocalTestingException( - f"Object '{table}' does not exist or not authorized." - ) + + db_schema_table = parse_table_name(entity_name) + table = ".".join([part.strip("\"'") for part in db_schema_table[:3]]) + raise SnowparkLocalTestingException( + f"Object '{table}' does not exist or not authorized." 
+ ) if isinstance(source_plan, Aggregate): child_rf = execute_mock_plan(source_plan.child, expr_to_alias) if ( @@ -1111,28 +1120,30 @@ def outer_join(base_df): ) if isinstance(source_plan, SnowflakeTable): entity_name = source_plan.name - if entity_registry.is_existing_table(entity_name): - return entity_registry.read_table(entity_name) - elif entity_registry.is_existing_view(entity_name): - execution_plan = entity_registry.get_review(entity_name) + table = entity_registry.read_table_if_exists(entity_name) + if table is not None: + return table + + execution_plan = entity_registry.read_view_if_exists(entity_name) + if execution_plan is not None: res_df = execute_mock_plan(execution_plan, expr_to_alias) return res_df - else: - obj_name_tuple = parse_table_name(entity_name) - obj_name = obj_name_tuple[-1] - obj_schema = ( - obj_name_tuple[-2] - if len(obj_name_tuple) > 1 - else analyzer.session.get_current_schema() - ) - obj_database = ( - obj_name_tuple[-3] - if len(obj_name_tuple) > 2 - else analyzer.session.get_current_database() - ) - raise SnowparkLocalTestingException( - f"Object '{obj_database[1:-1]}.{obj_schema[1:-1]}.{obj_name[1:-1]}' does not exist or not authorized." - ) + + obj_name_tuple = parse_table_name(entity_name) + obj_name = obj_name_tuple[-1] + obj_schema = ( + obj_name_tuple[-2] + if len(obj_name_tuple) > 1 + else analyzer.session.get_current_schema() + ) + obj_database = ( + obj_name_tuple[-3] + if len(obj_name_tuple) > 2 + else analyzer.session.get_current_database() + ) + raise SnowparkLocalTestingException( + f"Object '{obj_database[1:-1]}.{obj_schema[1:-1]}.{obj_name[1:-1]}' does not exist or not authorized." + ) if isinstance(source_plan, Sample): res_df = execute_mock_plan(source_plan.child, expr_to_alias) @@ -1159,272 +1170,283 @@ def outer_join(base_df): return from_df if isinstance(source_plan, TableUpdate): - target = entity_registry.read_table(source_plan.table_name) - ROW_ID = "row_id_" + generate_random_alphanumeric() - target.insert(0, ROW_ID, range(len(target))) + # since we are modifying the table, we need to ensure that no other thread + # reads the table until it is updated + with connection_lock: + target = entity_registry.read_table(source_plan.table_name) + ROW_ID = "row_id_" + generate_random_alphanumeric() + target.insert(0, ROW_ID, range(len(target))) + + if source_plan.source_data: + # Calculate cartesian product + source = execute_mock_plan(source_plan.source_data, expr_to_alias) + cartesian_product = target.merge(source, on=None, how="cross") + cartesian_product.sf_types.update(target.sf_types) + cartesian_product.sf_types.update(source.sf_types) + intermediate = cartesian_product + else: + intermediate = target - if source_plan.source_data: - # Calculate cartesian product - source = execute_mock_plan(source_plan.source_data, expr_to_alias) - cartesian_product = target.merge(source, on=None, how="cross") - cartesian_product.sf_types.update(target.sf_types) - cartesian_product.sf_types.update(source.sf_types) - intermediate = cartesian_product - else: - intermediate = target + if source_plan.condition: + # Select rows to be updated based on condition + condition = calculate_expression( + source_plan.condition, intermediate, analyzer, expr_to_alias + ).fillna(value=False) - if source_plan.condition: - # Select rows to be updated based on condition - condition = calculate_expression( - source_plan.condition, intermediate, analyzer, expr_to_alias - ).fillna(value=False) - - matched = target.apply(tuple, 1).isin( - 
intermediate[condition][target.columns].apply(tuple, 1) + matched = target.apply(tuple, 1).isin( + intermediate[condition][target.columns].apply(tuple, 1) + ) + matched.sf_type = ColumnType(BooleanType(), True) + matched_rows = target[matched] + intermediate = intermediate[condition] + else: + matched_rows = target + + # Calculate multi_join + matched_count = intermediate[target.columns].value_counts(dropna=False)[ + matched_rows.apply(tuple, 1) + ] + multi_joins = matched_count.where(lambda x: x > 1).count() + + # Select rows that match the condition to be updated + rows_to_update = intermediate.drop_duplicates( + subset=matched_rows.columns, keep="first" + ).reset_index( # ERROR_ON_NONDETERMINISTIC_UPDATE is by default False, pick one row to update + drop=True ) - matched.sf_type = ColumnType(BooleanType(), True) - matched_rows = target[matched] - intermediate = intermediate[condition] - else: - matched_rows = target + rows_to_update.sf_types = intermediate.sf_types + + # Update rows in place + for attr, new_expr in source_plan.assignments.items(): + column_name = analyzer.analyze(attr, expr_to_alias) + target_index = target.loc[rows_to_update[ROW_ID]].index + new_val = calculate_expression( + new_expr, rows_to_update, analyzer, expr_to_alias + ) + new_val.index = target_index + target.loc[rows_to_update[ROW_ID], column_name] = new_val - # Calculate multi_join - matched_count = intermediate[target.columns].value_counts(dropna=False)[ - matched_rows.apply(tuple, 1) - ] - multi_joins = matched_count.where(lambda x: x > 1).count() + # Delete row_id + target = target.drop(ROW_ID, axis=1) - # Select rows that match the condition to be updated - rows_to_update = intermediate.drop_duplicates( - subset=matched_rows.columns, keep="first" - ).reset_index( # ERROR_ON_NONDETERMINISTIC_UPDATE is by default False, pick one row to update - drop=True - ) - rows_to_update.sf_types = intermediate.sf_types - - # Update rows in place - for attr, new_expr in source_plan.assignments.items(): - column_name = analyzer.analyze(attr, expr_to_alias) - target_index = target.loc[rows_to_update[ROW_ID]].index - new_val = calculate_expression( - new_expr, rows_to_update, analyzer, expr_to_alias + # Write result back to table + entity_registry.write_table( + source_plan.table_name, target, SaveMode.OVERWRITE ) - new_val.index = target_index - target.loc[rows_to_update[ROW_ID], column_name] = new_val - - # Delete row_id - target = target.drop(ROW_ID, axis=1) - - # Write result back to table - entity_registry.write_table(source_plan.table_name, target, SaveMode.OVERWRITE) return [Row(len(rows_to_update), multi_joins)] elif isinstance(source_plan, TableDelete): - target = entity_registry.read_table(source_plan.table_name) + # since we are modifying the table, we need to ensure that no other thread + # reads the table until it is updated + with connection_lock: + target = entity_registry.read_table(source_plan.table_name) + + if source_plan.source_data: + # Calculate cartesian product + source = execute_mock_plan(source_plan.source_data, expr_to_alias) + cartesian_product = target.merge(source, on=None, how="cross") + cartesian_product.sf_types.update(target.sf_types) + cartesian_product.sf_types.update(source.sf_types) + intermediate = cartesian_product + else: + intermediate = target + + # Select rows to keep based on condition + if source_plan.condition: + condition = calculate_expression( + source_plan.condition, intermediate, analyzer, expr_to_alias + ).fillna(value=False) + intermediate = intermediate[condition] + 
matched = target.apply(tuple, 1).isin( + intermediate[target.columns].apply(tuple, 1) + ) + matched.sf_type = ColumnType(BooleanType(), True) + rows_to_keep = target[~matched] + else: + rows_to_keep = target.head(0) - if source_plan.source_data: + # Write rows to keep to table registry + entity_registry.write_table( + source_plan.table_name, rows_to_keep, SaveMode.OVERWRITE + ) + return [Row(len(target) - len(rows_to_keep))] + elif isinstance(source_plan, TableMerge): + # since we are modifying the table, we need to ensure that no other thread + # reads the table until it is updated + with connection_lock: + target = entity_registry.read_table(source_plan.table_name) + ROW_ID = "row_id_" + generate_random_alphanumeric() + SOURCE_ROW_ID = "source_row_id_" + generate_random_alphanumeric() # Calculate cartesian product - source = execute_mock_plan(source_plan.source_data, expr_to_alias) + source = execute_mock_plan(source_plan.source, expr_to_alias) + + # Insert row_id and source row_id + target.insert(0, ROW_ID, range(len(target))) + source.insert(0, SOURCE_ROW_ID, range(len(source))) + cartesian_product = target.merge(source, on=None, how="cross") cartesian_product.sf_types.update(target.sf_types) cartesian_product.sf_types.update(source.sf_types) - intermediate = cartesian_product - else: - intermediate = target - - # Select rows to keep based on condition - if source_plan.condition: - condition = calculate_expression( - source_plan.condition, intermediate, analyzer, expr_to_alias - ).fillna(value=False) - intermediate = intermediate[condition] - matched = target.apply(tuple, 1).isin( - intermediate[target.columns].apply(tuple, 1) + join_condition = calculate_expression( + source_plan.join_expr, cartesian_product, analyzer, expr_to_alias ) - matched.sf_type = ColumnType(BooleanType(), True) - rows_to_keep = target[~matched] - else: - rows_to_keep = target.head(0) + join_result = cartesian_product[join_condition] + join_result.sf_types = cartesian_product.sf_types + + # TODO [GA]: # ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if + # (1) A target row is selected to be updated with multiple values OR + # (2) A target row is selected to be both updated and deleted + + inserted_rows = [] + insert_clause_specified = ( + update_clause_specified + ) = delete_clause_specified = False + inserted_row_idx = set() # source_row_id + deleted_row_idx = set() + updated_row_idx = set() + for clause in source_plan.clauses: + if isinstance(clause, UpdateMergeExpression): + update_clause_specified = True + # Select rows to update + if clause.condition: + condition = calculate_expression( + clause.condition, join_result, analyzer, expr_to_alias + ).fillna(value=False) + rows_to_update = join_result[condition] + else: + rows_to_update = join_result - # Write rows to keep to table registry - entity_registry.write_table( - source_plan.table_name, rows_to_keep, SaveMode.OVERWRITE - ) - return [Row(len(target) - len(rows_to_keep))] - elif isinstance(source_plan, TableMerge): - target = entity_registry.read_table(source_plan.table_name) - ROW_ID = "row_id_" + generate_random_alphanumeric() - SOURCE_ROW_ID = "source_row_id_" + generate_random_alphanumeric() - # Calculate cartesian product - source = execute_mock_plan(source_plan.source, expr_to_alias) - - # Insert row_id and source row_id - target.insert(0, ROW_ID, range(len(target))) - source.insert(0, SOURCE_ROW_ID, range(len(source))) - - cartesian_product = target.merge(source, on=None, how="cross") - 
cartesian_product.sf_types.update(target.sf_types) - cartesian_product.sf_types.update(source.sf_types) - join_condition = calculate_expression( - source_plan.join_expr, cartesian_product, analyzer, expr_to_alias - ) - join_result = cartesian_product[join_condition] - join_result.sf_types = cartesian_product.sf_types - - # TODO [GA]: # ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if - # (1) A target row is selected to be updated with multiple values OR - # (2) A target row is selected to be both updated and deleted - - inserted_rows = [] - insert_clause_specified = ( - update_clause_specified - ) = delete_clause_specified = False - inserted_row_idx = set() # source_row_id - deleted_row_idx = set() - updated_row_idx = set() - for clause in source_plan.clauses: - if isinstance(clause, UpdateMergeExpression): - update_clause_specified = True - # Select rows to update - if clause.condition: - condition = calculate_expression( - clause.condition, join_result, analyzer, expr_to_alias - ).fillna(value=False) - rows_to_update = join_result[condition] - else: - rows_to_update = join_result + rows_to_update = rows_to_update[ + ~rows_to_update[ROW_ID] + .isin(updated_row_idx.union(deleted_row_idx)) + .values + ] - rows_to_update = rows_to_update[ - ~rows_to_update[ROW_ID] - .isin(updated_row_idx.union(deleted_row_idx)) - .values - ] + # Update rows in place + for attr, new_expr in clause.assignments.items(): + column_name = analyzer.analyze(attr, expr_to_alias) + target_index = target.loc[rows_to_update[ROW_ID]].index + new_val = calculate_expression( + new_expr, rows_to_update, analyzer, expr_to_alias + ) + new_val.index = target_index + target.loc[rows_to_update[ROW_ID], column_name] = new_val + + # Update updated row id set + for _, row in rows_to_update.iterrows(): + updated_row_idx.add(row[ROW_ID]) + + elif isinstance(clause, DeleteMergeExpression): + delete_clause_specified = True + # Select rows to delete + if clause.condition: + condition = calculate_expression( + clause.condition, join_result, analyzer, expr_to_alias + ).fillna(value=False) + intermediate = join_result[condition] + else: + intermediate = join_result - # Update rows in place - for attr, new_expr in clause.assignments.items(): - column_name = analyzer.analyze(attr, expr_to_alias) - target_index = target.loc[rows_to_update[ROW_ID]].index - new_val = calculate_expression( - new_expr, rows_to_update, analyzer, expr_to_alias + matched = target.apply(tuple, 1).isin( + intermediate[target.columns].apply(tuple, 1) ) - new_val.index = target_index - target.loc[rows_to_update[ROW_ID], column_name] = new_val - - # Update updated row id set - for _, row in rows_to_update.iterrows(): - updated_row_idx.add(row[ROW_ID]) - - elif isinstance(clause, DeleteMergeExpression): - delete_clause_specified = True - # Select rows to delete - if clause.condition: - condition = calculate_expression( - clause.condition, join_result, analyzer, expr_to_alias - ).fillna(value=False) - intermediate = join_result[condition] - else: - intermediate = join_result + matched.sf_type = ColumnType(BooleanType(), True) - matched = target.apply(tuple, 1).isin( - intermediate[target.columns].apply(tuple, 1) - ) - matched.sf_type = ColumnType(BooleanType(), True) + # Update deleted row id set + for _, row in target[matched].iterrows(): + deleted_row_idx.add(row[ROW_ID]) - # Update deleted row id set - for _, row in target[matched].iterrows(): - deleted_row_idx.add(row[ROW_ID]) + # Delete rows in place + target = target[~matched] - # Delete rows in 
place - target = target[~matched] + elif isinstance(clause, InsertMergeExpression): + insert_clause_specified = True + # calculate unmatched rows in the source + matched = source.apply(tuple, 1).isin( + join_result[source.columns].apply(tuple, 1) + ) + matched.sf_type = ColumnType(BooleanType(), True) + unmatched_rows_in_source = source[~matched] + + # select unmatched rows that qualify the condition + if clause.condition: + condition = calculate_expression( + clause.condition, + unmatched_rows_in_source, + analyzer, + expr_to_alias, + ).fillna(value=False) + unmatched_rows_in_source = unmatched_rows_in_source[condition] + + # filter out the unmatched rows that have been inserted in previous clauses + unmatched_rows_in_source = unmatched_rows_in_source[ + ~unmatched_rows_in_source[SOURCE_ROW_ID] + .isin(inserted_row_idx) + .values + ] - elif isinstance(clause, InsertMergeExpression): - insert_clause_specified = True - # calculate unmatched rows in the source - matched = source.apply(tuple, 1).isin( - join_result[source.columns].apply(tuple, 1) - ) - matched.sf_type = ColumnType(BooleanType(), True) - unmatched_rows_in_source = source[~matched] + # update inserted row idx set + for _, row in unmatched_rows_in_source.iterrows(): + inserted_row_idx.add(row[SOURCE_ROW_ID]) - # select unmatched rows that qualify the condition - if clause.condition: - condition = calculate_expression( - clause.condition, - unmatched_rows_in_source, - analyzer, - expr_to_alias, - ).fillna(value=False) - unmatched_rows_in_source = unmatched_rows_in_source[condition] - - # filter out the unmatched rows that have been inserted in previous clauses - unmatched_rows_in_source = unmatched_rows_in_source[ - ~unmatched_rows_in_source[SOURCE_ROW_ID] - .isin(inserted_row_idx) - .values - ] + # Calculate rows to insert + rows_to_insert = TableEmulator( + [], columns=target.drop(ROW_ID, axis=1).columns, dtype=object + ) + rows_to_insert.sf_types = target.sf_types + if clause.keys: + # Keep track of specified columns + inserted_columns = set() + for k, v in zip(clause.keys, clause.values): + column_name = analyzer.analyze(k, expr_to_alias) + if column_name not in rows_to_insert.columns: + raise SnowparkLocalTestingException( + f"invalid identifier '{column_name}'" + ) + inserted_columns.add(column_name) + new_val = calculate_expression( + v, unmatched_rows_in_source, analyzer, expr_to_alias + ) + # pandas could do implicit type conversion, e.g. 
from datetime to timestamp + # reconstructing ColumnEmulator helps preserve the original date type + rows_to_insert[column_name] = ColumnEmulator( + new_val.values, + dtype=object, + sf_type=rows_to_insert[column_name].sf_type, + ) - # update inserted row idx set - for _, row in unmatched_rows_in_source.iterrows(): - inserted_row_idx.add(row[SOURCE_ROW_ID]) + # For unspecified columns, use None as default value + for unspecified_col in set(rows_to_insert.columns).difference( + inserted_columns + ): + rows_to_insert[unspecified_col].replace( + np.nan, None, inplace=True + ) - # Calculate rows to insert - rows_to_insert = TableEmulator( - [], columns=target.drop(ROW_ID, axis=1).columns, dtype=object - ) - rows_to_insert.sf_types = target.sf_types - if clause.keys: - # Keep track of specified columns - inserted_columns = set() - for k, v in zip(clause.keys, clause.values): - column_name = analyzer.analyze(k, expr_to_alias) - if column_name not in rows_to_insert.columns: + else: + if len(clause.values) != len(rows_to_insert.columns): raise SnowparkLocalTestingException( - f"invalid identifier '{column_name}'" + f"Insert value list does not match column list expecting {len(rows_to_insert.columns)} but got {len(clause.values)}" ) - inserted_columns.add(column_name) - new_val = calculate_expression( - v, unmatched_rows_in_source, analyzer, expr_to_alias - ) - # pandas could do implicit type conversion, e.g. from datetime to timestamp - # reconstructing ColumnEmulator helps preserve the original date type - rows_to_insert[column_name] = ColumnEmulator( - new_val.values, - dtype=object, - sf_type=rows_to_insert[column_name].sf_type, - ) - - # For unspecified columns, use None as default value - for unspecified_col in set(rows_to_insert.columns).difference( - inserted_columns - ): - rows_to_insert[unspecified_col].replace( - np.nan, None, inplace=True - ) - - else: - if len(clause.values) != len(rows_to_insert.columns): - raise SnowparkLocalTestingException( - f"Insert value list does not match column list expecting {len(rows_to_insert.columns)} but got {len(clause.values)}" - ) - for col, v in zip(rows_to_insert.columns, clause.values): - new_val = calculate_expression( - v, unmatched_rows_in_source, analyzer, expr_to_alias - ) - rows_to_insert[col] = new_val + for col, v in zip(rows_to_insert.columns, clause.values): + new_val = calculate_expression( + v, unmatched_rows_in_source, analyzer, expr_to_alias + ) + rows_to_insert[col] = new_val - inserted_rows.append(rows_to_insert) + inserted_rows.append(rows_to_insert) - # Remove inserted ROW ID column - target = target.drop(ROW_ID, axis=1) + # Remove inserted ROW ID column + target = target.drop(ROW_ID, axis=1) - # Process inserted rows - if inserted_rows: - res = pd.concat([target] + inserted_rows) - res.sf_types = target.sf_types - else: - res = target + # Process inserted rows + if inserted_rows: + res = pd.concat([target] + inserted_rows) + res.sf_types = target.sf_types + else: + res = target - # Write the result back to table - entity_registry.write_table(source_plan.table_name, res, SaveMode.OVERWRITE) + # Write the result back to table + entity_registry.write_table(source_plan.table_name, res, SaveMode.OVERWRITE) # Generate metadata result res = [] diff --git a/src/snowflake/snowpark/mock/_stage_registry.py b/src/snowflake/snowpark/mock/_stage_registry.py index 7ed55d1cdc6..d4100606821 100644 --- a/src/snowflake/snowpark/mock/_stage_registry.py +++ b/src/snowflake/snowpark/mock/_stage_registry.py @@ -647,30 +647,34 @@ def __init__(self, 
conn: "MockServerConnection") -> None: self._root_dir = tempfile.TemporaryDirectory() self._stage_registry = {} self._conn = conn + self._lock = conn.get_lock() def create_or_replace_stage(self, stage_name): - self._stage_registry[stage_name] = StageEntity( - self._root_dir.name, stage_name, self._conn - ) + with self._lock: + self._stage_registry[stage_name] = StageEntity( + self._root_dir.name, stage_name, self._conn + ) def __getitem__(self, stage_name: str): # the assumption here is that stage always exists - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - return self._stage_registry[stage_name] + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + return self._stage_registry[stage_name] def put( self, local_file_name: str, stage_location: str, overwrite: bool = False ) -> TableEmulator: stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location) # the assumption here is that stage always exists - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - return self._stage_registry[stage_name].put_file( - local_file_name=local_file_name, - stage_prefix=stage_prefix, - overwrite=overwrite, - ) + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + return self._stage_registry[stage_name].put_file( + local_file_name=local_file_name, + stage_prefix=stage_prefix, + overwrite=overwrite, + ) def upload_stream( self, @@ -681,14 +685,15 @@ def upload_stream( ) -> Dict: stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location) # the assumption here is that stage always exists - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - return self._stage_registry[stage_name].upload_stream( - input_stream=input_stream, - stage_prefix=stage_prefix, - file_name=file_name, - overwrite=overwrite, - ) + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + return self._stage_registry[stage_name].upload_stream( + input_stream=input_stream, + stage_prefix=stage_prefix, + file_name=file_name, + overwrite=overwrite, + ) def get( self, @@ -701,14 +706,15 @@ def get( f"Invalid stage {stage_location}, stage name should start with character '@'" ) stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location) - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - - return self._stage_registry[stage_name].get_file( - stage_location=stage_prefix, - target_directory=target_directory, - options=options, - ) + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + + return self._stage_registry[stage_name].get_file( + stage_location=stage_prefix, + target_directory=target_directory, + options=options, + ) def read_file( self, @@ -723,13 +729,14 @@ def read_file( f"Invalid stage {stage_location}, stage name should start with character '@'" ) stage_name, stage_prefix = extract_stage_name_and_prefix(stage_location) - if stage_name not in self._stage_registry: - self.create_or_replace_stage(stage_name) - - return self._stage_registry[stage_name].read_file( - stage_location=stage_prefix, - format=format, - schema=schema, - analyzer=analyzer, - options=options, - ) + with self._lock: + if stage_name not in self._stage_registry: + self.create_or_replace_stage(stage_name) + + return self._stage_registry[stage_name].read_file( + 
stage_location=stage_prefix, + format=format, + schema=schema, + analyzer=analyzer, + options=options, + ) diff --git a/src/snowflake/snowpark/mock/_stored_procedure.py b/src/snowflake/snowpark/mock/_stored_procedure.py index d93500da2e8..14abec358c2 100644 --- a/src/snowflake/snowpark/mock/_stored_procedure.py +++ b/src/snowflake/snowpark/mock/_stored_procedure.py @@ -154,9 +154,11 @@ def __init__(self, *args, **kwargs) -> None: ) # maps name to either the callable or a pair of str (module_name, callable_name) self._sproc_level_imports = dict() # maps name to a set of file paths self._session_level_imports = set() + self._lock = self._session._conn.get_lock() def _clear_session_imports(self): - self._session_level_imports.clear() + with self._lock: + self._session_level_imports.clear() def _import_file( self, @@ -172,16 +174,17 @@ def _import_file( imports specified. """ - absolute_module_path, module_name = extract_import_dir_and_module_name( - file_path, self._session._conn.stage_registry, import_path - ) + with self._lock: + absolute_module_path, module_name = extract_import_dir_and_module_name( + file_path, self._session._conn.stage_registry, import_path + ) - if sproc_name: - self._sproc_level_imports[sproc_name].add(absolute_module_path) - else: - self._session_level_imports.add(absolute_module_path) + if sproc_name: + self._sproc_level_imports[sproc_name].add(absolute_module_path) + else: + self._session_level_imports.add(absolute_module_path) - return module_name + return module_name def _do_register_sp( self, @@ -224,90 +227,96 @@ def _do_register_sp( error_message="Registering anonymous sproc is not currently supported.", raise_error=NotImplementedError, ) - ( - sproc_name, - is_pandas_udf, - is_dataframe_input, - return_type, - input_types, - opt_arg_defaults, - ) = process_registration_inputs( - self._session, - TempObjectType.PROCEDURE, - func, - return_type, - input_types, - sp_name, - anonymous, - ) - current_schema = self._session.get_current_schema() - current_database = self._session.get_current_database() - sproc_name = get_fully_qualified_name( - sproc_name, current_schema, current_database - ) - - check_python_runtime_version(self._session._runtime_version_from_requirement) - - if replace and if_not_exists: - raise ValueError("options replace and if_not_exists are incompatible") + with self._lock: + ( + sproc_name, + is_pandas_udf, + is_dataframe_input, + return_type, + input_types, + opt_arg_defaults, + ) = process_registration_inputs( + self._session, + TempObjectType.PROCEDURE, + func, + return_type, + input_types, + sp_name, + anonymous, + ) - if sproc_name in self._registry and if_not_exists: - return self._registry[sproc_name] + current_schema = self._session.get_current_schema() + current_database = self._session.get_current_database() + sproc_name = get_fully_qualified_name( + sproc_name, current_schema, current_database + ) - if sproc_name in self._registry and not replace: - raise SnowparkLocalTestingException( - f"002002 (42710): SQL compilation error: \nObject '{sproc_name}' already exists.", - error_code="1304", + check_python_runtime_version( + self._session._runtime_version_from_requirement ) - if is_pandas_udf: - raise TypeError("pandas stored procedure is not supported") + if replace and if_not_exists: + raise ValueError("options replace and if_not_exists are incompatible") - if packages: - pass # NO-OP + if sproc_name in self._registry and if_not_exists: + return self._registry[sproc_name] - if imports is not None or type(func) is tuple: - 
self._sproc_level_imports[sproc_name] = set() + if sproc_name in self._registry and not replace: + raise SnowparkLocalTestingException( + f"002002 (42710): SQL compilation error: \nObject '{sproc_name}' already exists.", + error_code="1304", + ) - if imports is not None: - for _import in imports: - if isinstance(_import, str): - self._import_file(_import, sproc_name=sproc_name) - elif isinstance(_import, tuple) and all( - isinstance(item, str) for item in _import - ): - local_path, import_path = _import - self._import_file(local_path, import_path, sproc_name=sproc_name) - else: - raise TypeError( - "stored-proc-level import can only be a file path (str) or a tuple of the file path (str) and the import path (str)" - ) + if is_pandas_udf: + raise TypeError("pandas stored procedure is not supported") - if type(func) is tuple: # register from file - if sproc_name not in self._sproc_level_imports: - self._sproc_level_imports[sproc_name] = set() - module_name = self._import_file(func[0], sproc_name=sproc_name) - func = (module_name, func[1]) + if packages: + pass # NO-OP - if sproc_name in self._sproc_level_imports: - sproc_imports = self._sproc_level_imports[sproc_name] - else: - sproc_imports = copy(self._session_level_imports) + if imports is not None or type(func) is tuple: + self._sproc_level_imports[sproc_name] = set() - sproc = MockStoredProcedure( - func, - return_type, - input_types, - sproc_name, - sproc_imports, - execute_as=execute_as, - strict=strict, - ) + if imports is not None: + for _import in imports: + if isinstance(_import, str): + self._import_file(_import, sproc_name=sproc_name) + elif isinstance(_import, tuple) and all( + isinstance(item, str) for item in _import + ): + local_path, import_path = _import + self._import_file( + local_path, import_path, sproc_name=sproc_name + ) + else: + raise TypeError( + "stored-proc-level import can only be a file path (str) or a tuple of the file path (str) and the import path (str)" + ) + + if type(func) is tuple: # register from file + if sproc_name not in self._sproc_level_imports: + self._sproc_level_imports[sproc_name] = set() + module_name = self._import_file(func[0], sproc_name=sproc_name) + func = (module_name, func[1]) + + if sproc_name in self._sproc_level_imports: + sproc_imports = self._sproc_level_imports[sproc_name] + else: + sproc_imports = copy(self._session_level_imports) + + sproc = MockStoredProcedure( + func, + return_type, + input_types, + sproc_name, + sproc_imports, + execute_as=execute_as, + strict=strict, + ) - self._registry[sproc_name] = sproc + self._registry[sproc_name] = sproc - return sproc + return sproc def call( self, @@ -316,17 +325,18 @@ def call( session: Optional["snowflake.snowpark.session.Session"] = None, statement_params: Optional[Dict[str, str]] = None, ): - current_schema = self._session.get_current_schema() - current_database = self._session.get_current_database() - sproc_name = get_fully_qualified_name( - sproc_name, current_schema, current_database - ) - - if sproc_name not in self._registry: - raise SnowparkLocalTestingException( - f"Unknown function {sproc_name}. Stored procedure by that name does not exist." 
+ with self._lock: + current_schema = self._session.get_current_schema() + current_database = self._session.get_current_database() + sproc_name = get_fully_qualified_name( + sproc_name, current_schema, current_database ) - return self._registry[sproc_name]( - *args, session=session, statement_params=statement_params - ) + if sproc_name not in self._registry: + raise SnowparkLocalTestingException( + f"Unknown function {sproc_name}. Stored procedure by that name does not exist." + ) + + sproc = self._registry[sproc_name] + + return sproc(*args, session=session, statement_params=statement_params) diff --git a/src/snowflake/snowpark/mock/_telemetry.py b/src/snowflake/snowpark/mock/_telemetry.py index 857291b47fd..6e4273aa7ff 100644 --- a/src/snowflake/snowpark/mock/_telemetry.py +++ b/src/snowflake/snowpark/mock/_telemetry.py @@ -5,6 +5,7 @@ import json import logging import os +import threading import uuid from datetime import datetime from enum import Enum @@ -92,6 +93,7 @@ def __init__(self) -> None: ) self._deployment_url = self.PROD self._enable = True + self._lock = threading.RLock() def _upload_payload(self, payload) -> None: if not REQUESTS_AVAILABLE: @@ -136,12 +138,25 @@ def add(self, event) -> None: if not self.enabled: return - self.queue.put(event) - if self.queue.qsize() > self.batch_size: - payload = self.export_queue_to_string() - if payload is None: - return - self._upload_payload(payload) + with self._lock: + self.queue.put(event) + if self.queue.qsize() > self.batch_size: + payload = self.export_queue_to_string() + if payload is None: + return + self._upload_payload(payload) + + def flush(self) -> None: + """Flushes all telemetry events in the queue and submit them to the back-end.""" + if not self.enabled: + return + + with self._lock: + if not self.queue.empty(): + payload = self.export_queue_to_string() + if payload is None: + return + self._upload_payload(payload) @property def enabled(self) -> bool: @@ -158,8 +173,9 @@ def disable(self) -> None: def export_queue_to_string(self): logs = list() - while not self.queue.empty(): - logs.append(self.queue.get()) + with self._lock: + while not self.queue.empty(): + logs.append(self.queue.get()) # We may get an exception trying to serialize a python object to JSON try: payload = json.dumps(logs) diff --git a/src/snowflake/snowpark/mock/_udf.py b/src/snowflake/snowpark/mock/_udf.py index 7cedf0de660..a7a17d9a030 100644 --- a/src/snowflake/snowpark/mock/_udf.py +++ b/src/snowflake/snowpark/mock/_udf.py @@ -38,9 +38,11 @@ def __init__(self, *args, **kwargs) -> None: dict() ) # maps udf name to either the callable or a pair of str (module_name, callable_name) self._session_level_imports = set() + self._lock = self._session._conn.get_lock() def _clear_session_imports(self): - self._session_level_imports.clear() + with self._lock: + self._session_level_imports.clear() def _import_file( self, @@ -54,29 +56,32 @@ def _import_file( When udf_name is not None, the import is added to the UDF associated with the name; Otherwise, it is a session level import and will be used if no UDF-level imports are specified. 
""" - absolute_module_path, module_name = extract_import_dir_and_module_name( - file_path, self._session._conn.stage_registry, import_path - ) - if udf_name: - self._registry[udf_name].add_import(absolute_module_path) - else: - self._session_level_imports.add(absolute_module_path) + with self._lock: + absolute_module_path, module_name = extract_import_dir_and_module_name( + file_path, self._session._conn.stage_registry, import_path + ) + if udf_name: + self._registry[udf_name].add_import(absolute_module_path) + else: + self._session_level_imports.add(absolute_module_path) - return module_name + return module_name def get_udf(self, udf_name: str) -> MockUserDefinedFunction: - if udf_name not in self._registry: - raise SnowparkLocalTestingException(f"udf {udf_name} does not exist.") - return self._registry[udf_name] + with self._lock: + if udf_name not in self._registry: + raise SnowparkLocalTestingException(f"udf {udf_name} does not exist.") + return self._registry[udf_name] def get_udf_imports(self, udf_name: str) -> Set[str]: - udf = self._registry.get(udf_name) - if not udf: - return set() - elif udf.use_session_imports: - return self._session_level_imports - else: - return udf._imports + with self._lock: + udf = self._registry.get(udf_name) + if not udf: + return set() + elif udf.use_session_imports: + return self._session_level_imports + else: + return udf._imports def _do_register_udf( self, @@ -113,73 +118,81 @@ def _do_register_udf( raise_error=NotImplementedError, ) - # get the udf name, return and input types - ( - udf_name, - is_pandas_udf, - is_dataframe_input, - return_type, - input_types, - opt_arg_defaults, - ) = process_registration_inputs( - self._session, TempObjectType.FUNCTION, func, return_type, input_types, name - ) - - current_schema = self._session.get_current_schema() - current_database = self._session.get_current_database() - udf_name = get_fully_qualified_name(udf_name, current_schema, current_database) - - # allow registering pandas UDF from udf(), - # but not allow registering non-pandas UDF from pandas_udf() - if from_pandas_udf_function and not is_pandas_udf: - raise ValueError( - "You cannot create a non-vectorized UDF using pandas_udf(). " - "Use udf() instead." + with self._lock: + # get the udf name, return and input types + ( + udf_name, + is_pandas_udf, + is_dataframe_input, + return_type, + input_types, + opt_arg_defaults, + ) = process_registration_inputs( + self._session, + TempObjectType.FUNCTION, + func, + return_type, + input_types, + name, ) - custom_python_runtime_version_allowed = False + current_schema = self._session.get_current_schema() + current_database = self._session.get_current_database() + udf_name = get_fully_qualified_name( + udf_name, current_schema, current_database + ) - if not custom_python_runtime_version_allowed: - check_python_runtime_version( - self._session._runtime_version_from_requirement + # allow registering pandas UDF from udf(), + # but not allow registering non-pandas UDF from pandas_udf() + if from_pandas_udf_function and not is_pandas_udf: + raise ValueError( + "You cannot create a non-vectorized UDF using pandas_udf(). " + "Use udf() instead." 
+ ) + + custom_python_runtime_version_allowed = False + + if not custom_python_runtime_version_allowed: + check_python_runtime_version( + self._session._runtime_version_from_requirement + ) + + if replace and if_not_exists: + raise ValueError("options replace and if_not_exists are incompatible") + + if udf_name in self._registry and if_not_exists: + return self._registry[udf_name] + + if udf_name in self._registry and not replace: + raise SnowparkSQLException( + f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.", + error_code="1304", + ) + + if packages: + pass # NO-OP + + # register + self._registry[udf_name] = MockUserDefinedFunction( + func, + return_type, + input_types, + udf_name, + strict=strict, + packages=packages, + use_session_imports=imports is None, ) - if replace and if_not_exists: - raise ValueError("options replace and if_not_exists are incompatible") + if type(func) is tuple: # update file registration + module_name = self._import_file(func[0], udf_name=udf_name) + self._registry[udf_name].func = (module_name, func[1]) - if udf_name in self._registry and if_not_exists: - return self._registry[udf_name] + if imports is not None: + for _import in imports: + if type(_import) is str: + self._import_file(_import, udf_name=udf_name) + else: + local_path, import_path = _import + self._import_file(local_path, import_path, udf_name=udf_name) - if udf_name in self._registry and not replace: - raise SnowparkSQLException( - f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.", - error_code="1304", - ) - - if packages: - pass # NO-OP - - # register - self._registry[udf_name] = MockUserDefinedFunction( - func, - return_type, - input_types, - udf_name, - strict=strict, - packages=packages, - use_session_imports=imports is None, - ) - - if type(func) is tuple: # update file registration - module_name = self._import_file(func[0], udf_name=udf_name) - self._registry[udf_name].func = (module_name, func[1]) - - if imports is not None: - for _import in imports: - if type(_import) is str: - self._import_file(_import, udf_name=udf_name) - else: - local_path, import_path = _import - self._import_file(local_path, import_path, udf_name=udf_name) - - return self._registry[udf_name] + return self._registry[udf_name] diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a04e381e985..8da0794f139 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3093,7 +3093,9 @@ def _use_object(self, object_name: str, object_type: str) -> None: # we do not validate here object_type = match.group(1) object_name = match.group(2) - setattr(self._conn, f"_active_{object_type}", object_name) + mock_conn_lock = self._conn.get_lock() + with mock_conn_lock: + setattr(self._conn, f"_active_{object_type}", object_name) else: self._run_query(query) else: diff --git a/tests/mock/test_multithreading.py b/tests/mock/test_multithreading.py new file mode 100644 index 00000000000..bae771e8e77 --- /dev/null +++ b/tests/mock/test_multithreading.py @@ -0,0 +1,335 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
+# + +import io +import json +import os +import tempfile +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import Thread + +import pytest + +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + LogicalPlan, + SaveMode, +) +from snowflake.snowpark._internal.utils import normalize_local_file +from snowflake.snowpark.functions import lit, when_matched +from snowflake.snowpark.mock._connection import MockServerConnection +from snowflake.snowpark.mock._functions import MockedFunctionRegistry +from snowflake.snowpark.mock._plan import MockExecutionPlan +from snowflake.snowpark.mock._snowflake_data_type import TableEmulator +from snowflake.snowpark.mock._stage_registry import StageEntityRegistry +from snowflake.snowpark.mock._telemetry import LocalTestOOBTelemetryService +from snowflake.snowpark.row import Row +from snowflake.snowpark.session import Session +from tests.utils import Utils + + +def test_table_update_merge_delete(session): + table_name = Utils.random_table_name() + num_threads = 10 + data = [[v, 11 * v] for v in range(10)] + df = session.create_dataframe(data, schema=["a", "b"]) + df.write.save_as_table(table_name, table_type="temp") + + source_df = df + t = session.table(table_name) + + def update_table(thread_id: int): + t.update({"b": 0}, t.a == lit(thread_id)) + + def merge_table(thread_id: int): + t.merge( + source_df, t.a == source_df.a, [when_matched().update({"b": source_df.b})] + ) + + def delete_table(thread_id: int): + t.delete(t.a == lit(thread_id)) + + # all threads will update column b to 0 where a = thread_id + with ThreadPoolExecutor(max_workers=num_threads) as executor: + # update + futures = [executor.submit(update_table, i) for i in range(num_threads)] + for future in as_completed(futures): + future.result() + + # all threads will set column b to 0 + Utils.check_answer(t.select(t.b), [Row(B=0) for _ in range(10)]) + + # merge + futures = [executor.submit(merge_table, i) for i in range(num_threads)] + for future in as_completed(futures): + future.result() + + # all threads will set column b to 11 * a + Utils.check_answer(t.select(t.b), [Row(B=11 * i) for i in range(10)]) + + # delete + futures = [executor.submit(delete_table, i) for i in range(num_threads)] + for future in as_completed(futures): + future.result() + + # all threads will delete their row + assert t.count() == 0 + + +def test_udf_register_and_invoke(session): + df = session.create_dataframe([[1], [2]], schema=["num"]) + num_threads = 10 + + def register_udf(x: int): + def echo(x: int) -> int: + return x + + return session.udf.register(echo, name="echo", replace=True) + + def invoke_udf(): + result = df.select(session.udf.call_udf("echo", df.num)).collect() + assert result[0][0] == 1 + assert result[1][0] == 2 + + threads = [] + for i in range(num_threads): + thread_register = Thread(target=register_udf, args=(i,)) + threads.append(thread_register) + thread_register.start() + + thread_invoke = Thread(target=invoke_udf) + threads.append(thread_invoke) + thread_invoke.start() + + for thread in threads: + thread.join() + + +def test_sp_register_and_invoke(session): + num_threads = 10 + + def increment_by_one_fn(session_: Session, x: int) -> int: + return x + 1 + + def register_sproc(): + session.sproc.register( + increment_by_one_fn, name="increment_by_one", replace=True + ) + + def invoke_sproc(): + result = session.call("increment_by_one", 1) + assert result == 2 + + threads = [] + for i in range(num_threads): + thread_register = 
Thread(target=register_sproc, args=(i,)) + threads.append(thread_register) + thread_register.start() + + thread_invoke = Thread(target=invoke_sproc) + threads.append(thread_invoke) + thread_invoke.start() + + for thread in threads: + thread.join() + + +def test_mocked_function_registry_created_once(): + num_threads = 10 + + result = [] + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(MockedFunctionRegistry.get_or_create) + for _ in range(num_threads) + ] + + for future in as_completed(futures): + result.append(future.result()) + + registry = MockedFunctionRegistry.get_or_create() + assert all([registry is r for r in result]) + + +@pytest.mark.parametrize("test_table", [True, False]) +def test_tabular_entity_registry(test_table): + conn = MockServerConnection() + entity_registry = conn.entity_registry + num_threads = 10 + + def write_read_and_drop_table(): + table_name = "test_table" + table_emulator = TableEmulator() + + entity_registry.write_table(table_name, table_emulator, SaveMode.OVERWRITE) + + optional_table = entity_registry.read_table_if_exists(table_name) + if optional_table is not None: + assert optional_table.empty + + entity_registry.drop_table(table_name) + + def write_read_and_drop_view(): + view_name = "test_view" + empty_logical_plan = LogicalPlan() + plan = MockExecutionPlan(empty_logical_plan, None) + + entity_registry.create_or_replace_view(plan, view_name) + + optional_view = entity_registry.read_view_if_exists(view_name) + if optional_view is not None: + assert optional_view.source_plan == empty_logical_plan + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + if test_table: + test_fn = write_read_and_drop_table + else: + test_fn = write_read_and_drop_view + futures = [executor.submit(test_fn) for _ in range(num_threads)] + + for future in as_completed(futures): + future.result() + + +def test_stage_entity_registry_put_and_get(): + stage_registry = StageEntityRegistry(MockServerConnection()) + num_threads = 10 + + def put_and_get_file(): + stage_registry.put( + normalize_local_file( + f"{os.path.dirname(os.path.abspath(__file__))}/files/test_file_1" + ), + "@test_stage/test_parent_dir/test_child_dir", + ) + with tempfile.TemporaryDirectory() as temp_dir: + stage_registry.get( + "@test_stage/test_parent_dir/test_child_dir/test_file_1", + temp_dir, + ) + assert os.path.isfile(os.path.join(temp_dir, "test_file_1")) + + threads = [] + for _ in range(num_threads): + thread = Thread(target=put_and_get_file) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + +def test_stage_entity_registry_upload_and_read(session): + stage_registry = StageEntityRegistry(MockServerConnection()) + num_threads = 10 + + def upload_and_read_json(thread_id: int): + json_string = json.dumps({"thread_id": thread_id}) + bytes_io = io.BytesIO(json_string.encode("utf-8")) + stage_registry.upload_stream( + input_stream=bytes_io, + stage_location="@test_stage/test_parent_dir", + file_name=f"test_file_{thread_id}", + ) + + df = stage_registry.read_file( + f"@test_stage/test_parent_dir/test_file_{thread_id}", + "json", + [], + session._analyzer, + {"INFER_SCHEMA": "True"}, + ) + + assert df['"thread_id"'].iloc[0] == thread_id + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(upload_and_read_json, i) for i in range(num_threads)] + + for future in as_completed(futures): + future.result() + + +def test_stage_entity_registry_create_or_replace(): + stage_registry = 
StageEntityRegistry(MockServerConnection()) + num_threads = 10 + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(stage_registry.create_or_replace_stage, f"test_stage_{i}") + for i in range(num_threads) + ] + + for future in as_completed(futures): + future.result() + + assert len(stage_registry._stage_registry) == num_threads + for i in range(num_threads): + assert f"test_stage_{i}" in stage_registry._stage_registry + + +def test_oob_telemetry_add(): + oob_service = LocalTestOOBTelemetryService.get_instance() + # clean up queue first + oob_service.export_queue_to_string() + num_threads = 10 + num_events_per_thread = 10 + + # create a function that adds 10 events to the queue + def add_events(thread_id: int): + for i in range(num_events_per_thread): + oob_service.add( + {f"thread_{thread_id}_event_{i}": f"dummy_event_{thread_id}_{i}"} + ) + + # set batch_size to 101 + is_enabled = oob_service.enabled + oob_service.enable() + original_batch_size = oob_service.batch_size + oob_service.batch_size = num_threads * num_events_per_thread + 1 + try: + # create 10 threads + threads = [] + for thread_id in range(num_threads): + thread = Thread(target=add_events, args=(thread_id,)) + threads.append(thread) + thread.start() + + # wait for all threads to finish + for thread in threads: + thread.join() + + # assert that the queue size is 100 + assert oob_service.queue.qsize() == num_threads * num_events_per_thread + finally: + oob_service.batch_size = original_batch_size + if not is_enabled: + oob_service.disable() + + +def test_oob_telemetry_flush(): + oob_service = LocalTestOOBTelemetryService.get_instance() + # clean up queue first + oob_service.export_queue_to_string() + + is_enabled = oob_service.enabled + oob_service.enable() + # add a dummy event + oob_service.add({"event": "dummy_event"}) + + try: + # flush the queue in multiple threads + num_threads = 10 + threads = [] + for _ in range(num_threads): + thread = Thread(target=oob_service.flush) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # assert that the queue is empty + assert oob_service.size() == 0 + finally: + if not is_enabled: + oob_service.disable()