diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml index 20a2d5274a2..0d8ca77dac3 100644 --- a/.github/workflows/precommit.yml +++ b/.github/workflows/precommit.yml @@ -85,7 +85,7 @@ jobs: matrix: os: [macos-latest, windows-latest-64-cores, ubuntu-latest-64-cores] python-version: ["3.9", "3.10", "3.11"] - cloud-provider: [aws, azure, gcp] + cloud-provider: [aws, gcp] # TODO: SNOW-1643374 add azure back exclude: # only run macos with aws py3.9 for doctest - os: macos-latest @@ -309,7 +309,7 @@ jobs: matrix: os: [macos-latest, windows-latest-64-cores, ubuntu-latest-64-cores] python-version: [ "3.9", "3.10", "3.11" ] - cloud-provider: [aws, azure, gcp] + cloud-provider: [aws, gcp] # TODO: SNOW-1643374 add azure back exclude: # only run macos with aws py3.9 for doctest - os: macos-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a4c9b8d5ee..1cf5845439c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,9 +4,14 @@ ### Snowpark Python API Updates +#### New Features + +- Added the following new functions in `snowflake.snowpark.functions`: + - `array_remove` + - `ln` + #### Improvements -- Added support for function `functions.ln` - Added support for specifying the following to `DataFrameWriter.save_as_table`: - `enable_schema_evolution` - `data_retention_time` @@ -45,13 +50,15 @@ #### New Features -- Added limited support for the `Timedelta` type, including +- Added limited support for the `Timedelta` type, including the following features. Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases. - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`. - - converting non-timedelta to timedelta via `astype`. + - converting non-timedelta to timedelta via `astype`. - `NotImplementedError` will be raised for the rest of methods that do not support `Timedelta`. - support for subtracting two timestamps to get a Timedelta. - - support indexing with Timedelta data columns. 
+ - support indexing with Timedelta data columns. - support for adding or subtracting timestamps and `Timedelta`. + - support for binary arithmetic between two `Timedelta` values. + - support for lazy `TimedeltaIndex`. - Added support for index's arithmetic and comparison operators. - Added support for `Series.dt.round`. - Added documentation pages for `DatetimeIndex`. diff --git a/docs/source/modin/indexing.rst b/docs/source/modin/indexing.rst index fa4f0538890..80ceba61bec 100644 --- a/docs/source/modin/indexing.rst +++ b/docs/source/modin/indexing.rst @@ -220,3 +220,43 @@ DatetimeIndex DatetimeIndex.mean DatetimeIndex.std + +.. _api.timedeltaindex: + +TimedeltaIndex +-------------- + +.. autosummary:: + :toctree: pandas_api/ + + TimedeltaIndex + +.. rubric:: `TimedeltaIndex` Components + +.. autosummary:: + :toctree: pandas_api/ + + TimedeltaIndex.days + TimedeltaIndex.seconds + TimedeltaIndex.microseconds + TimedeltaIndex.nanoseconds + TimedeltaIndex.components + TimedeltaIndex.inferred_freq + +.. rubric:: `TimedeltaIndex` Conversion + +.. autosummary:: + :toctree: pandas_api/ + + TimedeltaIndex.as_unit + TimedeltaIndex.to_pytimedelta + TimedeltaIndex.round + TimedeltaIndex.floor + TimedeltaIndex.ceil + +.. rubric:: `TimedeltaIndex` Methods + +.. 
autosummary:: + :toctree: pandas_api/ + + TimedeltaIndex.mean diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst index 1855eb314b3..6bb214e3bd6 100644 --- a/docs/source/modin/supported/dataframe_supported.rst +++ b/docs/source/modin/supported/dataframe_supported.rst @@ -98,8 +98,8 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``assign`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``astype`` | P | | ``N``: from string to datetime or ``errors == | -| | | | "ignore"`` | +| ``astype`` | P | | ``N`` if from string to datetime/timedelta or | +| | | | ``errors == "ignore"`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``at_time`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/index.rst b/docs/source/modin/supported/index.rst index 97202d69290..2d7999c4954 100644 --- a/docs/source/modin/supported/index.rst +++ b/docs/source/modin/supported/index.rst @@ -16,6 +16,7 @@ To view the docs for the most recent release, check that you’re viewing the st dataframe_supported index_supported datetime_index_supported + timedelta_index_supported window_supported groupby_supported resampling_supported diff --git a/docs/source/modin/supported/series_supported.rst b/docs/source/modin/supported/series_supported.rst index ea78a3a0e68..331be4d0298 100644 --- a/docs/source/modin/supported/series_supported.rst +++ b/docs/source/modin/supported/series_supported.rst @@ -105,8 +105,8 @@ Methods 
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``asof`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``astype`` | P | | ``N``: from string to datetime or ``errors == | -| | | | "ignore"`` | +| ``astype`` | P | | ``N`` if from string to datetime/timedelta or | +| | | | ``errors == "ignore"`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``at_time`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/timedelta_index_supported.rst b/docs/source/modin/supported/timedelta_index_supported.rst new file mode 100644 index 00000000000..73abe530fd7 --- /dev/null +++ b/docs/source/modin/supported/timedelta_index_supported.rst @@ -0,0 +1,48 @@ +``pd.TimedeltaIndex`` supported APIs +==================================== + +The following table is structured as follows: The first column contains the method name. +The second column is a flag for whether or not there is an implementation in Snowpark for +the method in the left column. + +.. note:: + ``Y`` stands for yes, i.e., supports distributed implementation, ``N`` stands for no and API simply errors out, + ``P`` stands for partial (meaning some parameters may not be supported yet), and ``D`` stands for defaults to single + node pandas execution via UDF/Sproc. + +Attributes + ++-----------------------------+---------------------------------+----------------------------------------------------+ +| TimedeltaIndex attribute | Snowpark implemented? 
(Y/N/P/D) | Notes for current implementation | ++-----------------------------+---------------------------------+----------------------------------------------------+ +| ``days`` | N | | | ++-----------------------------+---------------------------------+----------------------------------------------------+ +| ``seconds`` | N | | | ++-----------------------------+---------------------------------+----------------------------------------------------+ +| ``microseconds`` | N | | | ++-----------------------------+---------------------------------+----------------------------------------------------+ +| ``nanoseconds`` | N | | | ++-----------------------------+---------------------------------+----------------------------------------------------+ +| ``components`` | N | | | ++-----------------------------+---------------------------------+----------------------------------------------------+ +| ``inferred_freq`` | N | | | ++-----------------------------+---------------------------------+----------------------------------------------------+ + + +Methods + ++-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ +| TimedeltaIndex method | Snowpark implemented? 
(Y/N/P/D) | Missing parameters | Notes for current implementation | ++-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ +| ``as_unit`` | N | | | ++-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ +| ``to_pytimedelta`` | N | | | ++-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ +| ``round`` | N | | | ++-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ +| ``floor`` | N | | | ++-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ +| ``ceil`` | N | | | ++-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ +| ``mean`` | N | | | ++-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+ diff --git a/docs/source/snowpark/functions.rst b/docs/source/snowpark/functions.rst index 100cb5470fc..9a381e5046a 100644 --- a/docs/source/snowpark/functions.rst +++ b/docs/source/snowpark/functions.rst @@ -43,6 +43,7 @@ Functions array_min array_position array_prepend + array_remove array_size array_slice array_sort diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 83c0eebb9c7..aad369a8b83 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -1608,6 +1608,9 @@ def with_query_block( new_query = project_statement([], name) + # note we do not propagate the query parameter of 
the child here, + # the query parameter will be propagated along with the definition during + # query generation stage. queries = child.queries[:-1] + [Query(sql=new_query)] # propagate the cte table referenced_ctes = {name}.union(child.referenced_ctes) diff --git a/src/snowflake/snowpark/_internal/compiler/query_generator.py b/src/snowflake/snowpark/_internal/compiler/query_generator.py index fdf8af9d4dd..d9220424097 100644 --- a/src/snowflake/snowpark/_internal/compiler/query_generator.py +++ b/src/snowflake/snowpark/_internal/compiler/query_generator.py @@ -64,7 +64,7 @@ def __init__( # NOTE: the dict used here is an ordered dict, all with query block definition is recorded in the # order of when the with query block is visited. The order is important to make sure the dependency # between the CTE definition is satisfied. - self.resolved_with_query_block: Dict[str, str] = {} + self.resolved_with_query_block: Dict[str, Query] = {} def generate_queries( self, logical_plans: List[LogicalPlan] @@ -209,7 +209,7 @@ def do_resolve_with_resolved_children( if logical_plan.name not in self.resolved_with_query_block: self.resolved_with_query_block[ logical_plan.name - ] = resolved_child.queries[-1].sql + ] = resolved_child.queries[-1] resolved_plan = self.plan_builder.with_query_block( logical_plan.name, diff --git a/src/snowflake/snowpark/_internal/compiler/utils.py b/src/snowflake/snowpark/_internal/compiler/utils.py index 579c3b8e5d6..273ebe0440a 100644 --- a/src/snowflake/snowpark/_internal/compiler/utils.py +++ b/src/snowflake/snowpark/_internal/compiler/utils.py @@ -227,6 +227,9 @@ def update_resolvable_node( # re-calculation of the sql query and snowflake plan node._sql_query = None node._snowflake_plan = None + # make sure we also clean up the cached _projection_in_str, so that + # the projection expression can be re-analyzed during code generation + node._projection_in_str = None node.analyzer = query_generator # update the pre_actions and post_actions for the 
select statement @@ -267,12 +270,26 @@ def update_resolvable_node( update_resolvable_node(node.snowflake_plan, query_generator) node.analyzer = query_generator + node.pre_actions = node._snowflake_plan.queries[:-1] + node.post_actions = node._snowflake_plan.post_actions + node._api_calls = node._snowflake_plan.api_calls + + if isinstance(node, SelectSnowflakePlan): + node.expr_to_alias.update(node._snowflake_plan.expr_to_alias) + node.df_aliased_col_name_to_real_col_name.update( + node._snowflake_plan.df_aliased_col_name_to_real_col_name + ) + node._query_params = [] + for query in node._snowflake_plan.queries: + if query.params: + node._query_params.extend(query.params) + elif isinstance(node, Selectable): node.analyzer = query_generator def get_snowflake_plan_queries( - plan: SnowflakePlan, resolved_with_query_blocks: Dict[str, str] + plan: SnowflakePlan, resolved_with_query_blocks: Dict[str, Query] ) -> Dict[PlanQueryType, List[Query]]: from snowflake.snowpark._internal.analyzer.analyzer_utils import cte_statement @@ -286,12 +303,16 @@ def get_snowflake_plan_queries( post_action_queries = copy.deepcopy(plan.post_actions) table_names = [] definition_queries = [] + final_query_params = [] for name, definition_query in resolved_with_query_blocks.items(): if name in plan.referenced_ctes: table_names.append(name) - definition_queries.append(definition_query) + definition_queries.append(definition_query.sql) + final_query_params.extend(definition_query.params) with_query = cte_statement(definition_queries, table_names) plan_queries[-1].sql = with_query + plan_queries[-1].sql + final_query_params.extend(plan_queries[-1].params) + plan_queries[-1].params = final_query_params return { PlanQueryType.QUERIES: plan_queries, diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 8f8b156132c..58c2ab8518c 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -5333,6 +5333,56 @@ def 
array_append(array: ColumnOrName, element: ColumnOrName) -> Column: return builtin("array_append")(a, e) +def array_remove(array: ColumnOrName, element: ColumnOrLiteral) -> Column: + """Given a source ARRAY, returns an ARRAY with elements of the specified value removed. + + Args: + array: name of column containing array. + element: element to be removed from the array. If the element is a VARCHAR, it needs + to be cast to the VARIANT data type. + + Examples:: + >>> from snowflake.snowpark.types import VariantType + >>> df = session.create_dataframe([([1, '2', 3.1, 1, 1],)], ['data']) + >>> df.select(array_remove(df.data, 1).alias("objects")).show() + ------------- + |"OBJECTS" | + ------------- + |[ | + | "2", | + | 3.1 | + |] | + ------------- + + + >>> df.select(array_remove(df.data, lit('2').cast(VariantType())).alias("objects")).show() + ------------- + |"OBJECTS" | + ------------- + |[ | + | 1, | + | 3.1, | + | 1, | + | 1 | + |] | + ------------- + + + >>> df.select(array_remove(df.data, None).alias("objects")).show() + ------------- + |"OBJECTS" | + ------------- + |NULL | + ------------- + + + See Also: + - `ARRAY <https://docs.snowflake.com/en/sql-reference/data-types-semistructured#label-data-type-array>`_ for more details on semi-structured arrays. + """ + a = _to_col_if_str(array, "array_remove") + return builtin("array_remove")(a, element) + + def array_cat(array1: ColumnOrName, array2: ColumnOrName) -> Column: """Returns the concatenation of two ARRAYs. 
diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index 975289684cf..c4eb07d9589 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -63,7 +63,6 @@ SparseDtype, StringDtype, Timedelta, - TimedeltaIndex, Timestamp, UInt8Dtype, UInt16Dtype, @@ -156,6 +155,7 @@ from snowflake.snowpark.modin.plugin.extensions.pd_overrides import ( # isort: skip # noqa: E402,F401 Index, DatetimeIndex, + TimedeltaIndex, ) import snowflake.snowpark.modin.plugin.extensions.base_overrides # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.dataframe_extensions # isort: skip # noqa: E402,F401 diff --git a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py index 7d03940b7e0..a0ca357c59b 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from types import MappingProxyType +import numpy as np import pandas as native_pd from pandas._typing import Callable, Scalar @@ -236,10 +237,34 @@ def _compute_subtraction_between_snowpark_timestamp_columns( "rmul": "mul", "rsub": "sub", "rmod": "mod", + "__rand__": "__and__", + "__ror__": "__or__", } ) +def _op_is_between_two_timedeltas_or_timedelta_and_null( + first_datatype: DataType, second_datatype: DataType +) -> bool: + """ + Whether the binary operation is between two timedeltas, or between timedelta and null. + + Args: + first_datatype: First datatype + second_datatype: Second datatype + + Returns: + bool: Whether op is between two timedeltas or between timedelta and null. 
+ """ + return ( + isinstance(first_datatype, TimedeltaType) + and isinstance(second_datatype, (TimedeltaType, NullType)) + ) or ( + isinstance(first_datatype, (TimedeltaType, NullType)) + and isinstance(second_datatype, TimedeltaType) + ) + + def compute_binary_op_between_snowpark_columns( op: str, first_operand: SnowparkColumn, @@ -274,6 +299,7 @@ def compute_binary_op_between_snowpark_columns( ) binary_op_result_column = None + snowpark_pandas_type = None # some operators and the data types have to be handled specially to align with pandas # However, it is difficult to fail early if the arithmetic operator is not compatible @@ -290,7 +316,18 @@ def compute_binary_op_between_snowpark_columns( and isinstance(second_datatype(), TimestampType) ): binary_op_result_column = dateadd("ns", first_operand, second_operand) - elif op == "add" and ( + elif op in ( + "add", + "sub", + "eq", + "ne", + "gt", + "ge", + "lt", + "le", + "floordiv", + "truediv", + ) and ( ( isinstance(first_datatype(), TimedeltaType) and isinstance(second_datatype(), NullType) @@ -315,16 +352,66 @@ def compute_binary_op_between_snowpark_columns( # Timedelta - Timestamp doesn't make sense. Raise the same error # message as pandas. 
raise TypeError("bad operand type for unary -: 'DatetimeArray'") - elif isinstance(first_datatype(), TimedeltaType) or isinstance( - second_datatype(), TimedeltaType + elif op == "mod" and _op_is_between_two_timedeltas_or_timedelta_and_null( + first_datatype(), second_datatype() + ): + binary_op_result_column = compute_modulo_between_snowpark_columns( + first_operand, first_datatype(), second_operand, second_datatype() + ) + snowpark_pandas_type = TimedeltaType() + elif op == "pow" and _op_is_between_two_timedeltas_or_timedelta_and_null( + first_datatype(), second_datatype() + ): + raise TypeError("unsupported operand type for **: Timedelta") + elif op == "__or__" and _op_is_between_two_timedeltas_or_timedelta_and_null( + first_datatype(), second_datatype() + ): + raise TypeError("unsupported operand type for |: Timedelta") + elif op == "__and__" and _op_is_between_two_timedeltas_or_timedelta_and_null( + first_datatype(), second_datatype() + ): + raise TypeError("unsupported operand type for &: Timedelta") + elif ( + op in ("add", "sub") + and isinstance(first_datatype(), TimedeltaType) + and isinstance(second_datatype(), TimedeltaType) + ): + snowpark_pandas_type = TimedeltaType() + elif op == "mul" and _op_is_between_two_timedeltas_or_timedelta_and_null( + first_datatype(), second_datatype() + ): + raise np.core._exceptions._UFuncBinaryResolutionError( # type: ignore[attr-defined] + np.multiply, (np.dtype("timedelta64[ns]"), np.dtype("timedelta64[ns]")) + ) + elif _op_is_between_two_timedeltas_or_timedelta_and_null( + first_datatype(), second_datatype() + ) and op in ("eq", "ne", "gt", "ge", "lt", "le", "truediv"): + # These operations, when done between timedeltas, work without any + # extra handling in `snowpark_pandas_type` or `binary_op_result_column`. + # They produce outputs that are not timedeltas (e.g. numbers for floordiv + # and truediv, and bools for the comparisons). 
+ pass + elif ( + # equal_null and floordiv for timedelta also work without special + # handling, but we need to exclude them from the above case so we catch + # them in an `elif` clause further down. + op not in ("equal_null", "floordiv") + and ( + ( + isinstance(first_datatype(), TimedeltaType) + and not isinstance(second_datatype(), TimedeltaType) + ) + or ( + not isinstance(first_datatype(), TimedeltaType) + and isinstance(second_datatype(), TimedeltaType) + ) + ) ): # We don't support these cases yet. - # TODO(SNOW-1637101, SNOW-1637102): Support these cases. + # TODO(SNOW-1637102): Support this case. ErrorMessage.not_implemented( - f"Snowpark pandas does not yet support the binary operation {op} with timedelta types." + f"Snowpark pandas does not yet support the binary operation {op} with a Timedelta column and a non-Timedelta column." ) - elif op == "truediv": - binary_op_result_column = first_operand / second_operand elif op == "floordiv": binary_op_result_column = floor(first_operand / second_operand) elif op == "mod": @@ -335,9 +422,9 @@ def compute_binary_op_between_snowpark_columns( binary_op_result_column = compute_power_between_snowpark_columns( first_operand, second_operand ) - elif op in ["__or__", "__ror__"]: + elif op == "__or__": binary_op_result_column = first_operand | second_operand - elif op in ["__and__", "__rand__"]: + elif op == "__and__": binary_op_result_column = first_operand & second_operand elif ( op == "add" @@ -370,6 +457,8 @@ def compute_binary_op_between_snowpark_columns( pandas_lit(""), ) elif op == "equal_null": + # TODO(SNOW-1641716): In Snowpark pandas, generally use this equal_null + # with type checking instead of snowflake.snowpark.functions.equal_null. 
if not are_equal_types(first_datatype(), second_datatype()): binary_op_result_column = pandas_lit(False) else: @@ -409,7 +498,7 @@ def compute_binary_op_between_snowpark_columns( return SnowparkPandasColumn( snowpark_column=binary_op_result_column, - snowpark_pandas_type=None, + snowpark_pandas_type=snowpark_pandas_type, ) @@ -423,6 +512,10 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool: Returns: True if given types are equal, False otherwise. """ + if isinstance(type1, TimedeltaType) and not isinstance(type2, TimedeltaType): + return False + if isinstance(type2, TimedeltaType) and not isinstance(type1, TimedeltaType): + return False if isinstance(type1, _IntegralType) and isinstance(type2, _IntegralType): return True if isinstance(type1, _FractionalType) and isinstance(type2, _FractionalType): diff --git a/src/snowflake/snowpark/modin/plugin/_internal/type_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/type_utils.py index 1fd252bf7e0..261903be0f7 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/type_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/type_utils.py @@ -267,6 +267,8 @@ def to_pandas(cls, s: DataType) -> Union[np.dtype, ExtensionDtype]: return np.dtype("int64") if s.scale == 0 else np.dtype("float64") if isinstance(s, TimestampType): return np.dtype("datetime64[ns]") + if isinstance(s, TimedeltaType): + return np.dtype("timedelta64[ns]") # We also need to treat parameterized types correctly if isinstance(s, (StringType, ArrayType, MapType, GeographyType)): return np.dtype(np.object_) @@ -316,12 +318,12 @@ def column_astype( isinstance(from_sf_type, TimestampType) and from_sf_type.tz == TimestampTimeZone.LTZ ): - # treat TIMESTAMPT_LTZ columns as same as TIMESTAMPT_TZ + # treat TIMESTAMP_LTZ columns as same as TIMESTAMP_TZ curr_col = builtin("to_timestamp_tz")(curr_col) if isinstance(to_sf_type, TimestampType): assert to_sf_type.tz != TimestampTimeZone.LTZ, ( - "Cast to TIMESTAMPT_LTZ is not supported in 
astype since " + "Cast to TIMESTAMP_LTZ is not supported in astype since " "Snowpark pandas API maps tz aware datetime to TIMESTAMP_TZ" ) # convert to timestamp diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 31904a58b65..7e6336c397e 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -1730,7 +1730,6 @@ def set_index_from_series( Returns: The new SnowflakeQueryCompiler after the set_index operation """ - self._raise_not_implemented_error_for_timedelta() assert ( len(key._modin_frame.data_column_pandas_labels) == 1 @@ -1761,6 +1760,7 @@ def set_index_from_series( new_index_ids = result_column_mapper.map_right_quoted_identifiers( other_frame.data_column_snowflake_quoted_identifiers ) + new_index_snowpark_types = other_frame.cached_data_column_snowpark_pandas_types if append: new_index_labels = ( new_internal_frame.index_column_pandas_labels + new_index_labels @@ -1769,6 +1769,10 @@ def set_index_from_series( new_internal_frame.index_column_snowflake_quoted_identifiers + new_index_ids ) + new_index_snowpark_types = ( + self_frame.cached_index_column_snowpark_pandas_types + + new_index_snowpark_types + ) new_internal_frame = InternalFrame.create( ordered_dataframe=new_internal_frame.ordered_dataframe, data_column_pandas_labels=self_frame.data_column_pandas_labels, @@ -1778,8 +1782,8 @@ def set_index_from_series( ), index_column_pandas_labels=new_index_labels, index_column_snowflake_quoted_identifiers=new_index_ids, - data_column_types=None, - index_column_types=None, + data_column_types=self_frame.cached_data_column_snowpark_pandas_types, + index_column_types=new_index_snowpark_types, ) return SnowflakeQueryCompiler(new_internal_frame) @@ -5882,7 +5886,6 @@ def set_index_from_columns( Returns: A new QueryCompiler instance with updated 
index. """ - self._raise_not_implemented_error_for_timedelta() index_column_pandas_labels = keys index_column_snowflake_quoted_identifiers = [] @@ -5910,6 +5913,16 @@ def set_index_from_columns( self._modin_frame.data_column_snowflake_quoted_identifiers ) + id_to_type = ( + self._modin_frame.snowflake_quoted_identifier_to_snowpark_pandas_type + ) + index_column_snowpark_pandas_types = [ + id_to_type.get(id) for id in index_column_snowflake_quoted_identifiers + ] + data_column_snowpark_pandas_types = [ + id_to_type.get(id) for id in data_column_snowflake_quoted_identifiers + ] + # Generate aliases for new index columns if # 1. 'keys' are also kept as data columns, or # 2. 'keys' have duplicates. @@ -5944,6 +5957,10 @@ def set_index_from_columns( self._modin_frame.index_column_snowflake_quoted_identifiers + index_column_snowflake_quoted_identifiers ) + index_column_snowpark_pandas_types = ( + self._modin_frame.cached_index_column_snowpark_pandas_types + + index_column_snowpark_pandas_types + ) frame = InternalFrame.create( ordered_dataframe=ordered_dataframe, @@ -5952,8 +5969,8 @@ def set_index_from_columns( data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names, data_column_pandas_labels=data_column_pandas_labels, data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, - data_column_types=None, - index_column_types=None, + data_column_types=data_column_snowpark_pandas_types, + index_column_types=index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(frame) @@ -8778,7 +8795,15 @@ def astype( to_sf_type = TypeMapper.to_snowflake(to_dtype) from_dtype = col_dtypes_curr[label] from_sf_type = self._modin_frame.get_snowflake_type(id) - if is_astype_type_error(from_sf_type, to_sf_type): + if isinstance(from_sf_type, StringType) and isinstance( + to_sf_type, TimedeltaType + ): + # Raise NotImplementedError as there is no Snowflake SQL function converting + # string (e.g. 
1 day, 3 hours, 2 minutes) to Timedelta + ErrorMessage.not_implemented( + f"dtype {pandas_dtype(from_dtype)} cannot be converted to {pandas_dtype(to_dtype)}" + ) + elif is_astype_type_error(from_sf_type, to_sf_type): raise TypeError( f"dtype {pandas_dtype(from_dtype)} cannot be converted to {pandas_dtype(to_dtype)}" ) @@ -10255,13 +10280,16 @@ def drop( data_column_labels = [] data_column_identifiers = [] - for label, identifiers in zip( + data_column_snowpark_pandas_types = [] + for label, identifiers, type in zip( frame.data_column_pandas_labels, frame.data_column_snowflake_quoted_identifiers, + frame.cached_data_column_snowpark_pandas_types, ): if label not in data_column_labels_to_drop: data_column_labels.append(label) data_column_identifiers.append(identifiers) + data_column_snowpark_pandas_types.append(type) frame = InternalFrame.create( ordered_dataframe=frame.ordered_dataframe, @@ -10270,8 +10298,8 @@ def drop( data_column_pandas_labels=data_column_labels, data_column_snowflake_quoted_identifiers=data_column_identifiers, data_column_pandas_index_names=frame.data_column_pandas_index_names, - data_column_types=None, - index_column_types=None, + data_column_types=data_column_snowpark_pandas_types, + index_column_types=frame.cached_index_column_snowpark_pandas_types, ) frame = frame.select_active_columns() @@ -17257,17 +17285,27 @@ def compare( other._modin_frame.data_column_snowflake_quoted_identifiers, self._modin_frame.data_column_pandas_labels, ): + left_identiifer = result_column_mapper.left_quoted_identifiers_map[ + left_identifier + ] + right_identifier = result_column_mapper.right_quoted_identifiers_map[ + right_identifier + ] + op_result = compute_binary_op_between_snowpark_columns( + op="equal_null", + first_operand=col(left_identifier), + first_datatype=functools.partial( + lambda col: result_frame.get_snowflake_type(col), left_identiifer + ), + second_operand=col(right_identifier), + second_datatype=functools.partial( + lambda col: 
result_frame.get_snowflake_type(col), right_identifier + ), + ) binary_op_result = binary_op_result.append_column( str(left_pandas_label) + "_comparison_result", - col( - result_column_mapper.left_quoted_identifiers_map[left_identifier] - ).equal_null( - col( - result_column_mapper.right_quoted_identifiers_map[ - right_identifier - ] - ) - ), + op_result.snowpark_column, + op_result.snowpark_pandas_type, ) """ >>> SnowflakeQueryCompiler(binary_op_result).to_pandas() @@ -17345,28 +17383,48 @@ def compare( new_pandas_labels = [] new_values = [] column_index_tuples = [] + column_types = [] for ( pandas_column_value, pandas_label, left_identifier, right_identifier, column_only_contains_matches, + left_type, + right_type, ) in zip( self.columns, filtered_binary_op_result.data_column_pandas_labels, self._modin_frame.data_column_snowflake_quoted_identifiers, other._modin_frame.data_column_snowflake_quoted_identifiers, all_rows_match_frame.iloc[:, 0].values, + self._modin_frame.cached_data_column_snowpark_pandas_types, + other._modin_frame.cached_data_column_snowpark_pandas_types, ): # Drop columns that only contain matches. 
if column_only_contains_matches: continue - cols_equal = col( - result_column_mapper.left_quoted_identifiers_map[left_identifier] - ).equal_null( - col(result_column_mapper.right_quoted_identifiers_map[right_identifier]) - ) + left_mappped_identifier = result_column_mapper.left_quoted_identifiers_map[ + left_identifier + ] + right_mapped_identifier = result_column_mapper.right_quoted_identifiers_map[ + right_identifier + ] + + cols_equal = compute_binary_op_between_snowpark_columns( + op="equal_null", + first_operand=col(left_mappped_identifier), + first_datatype=functools.partial( + lambda col: result_frame.get_snowflake_type(col), + left_mappped_identifier, + ), + second_operand=col(right_mapped_identifier), + second_datatype=functools.partial( + lambda col: result_frame.get_snowflake_type(col), + right_mapped_identifier, + ), + ).snowpark_column # Add a column containing the values from `self`, but replace # matching values with null. @@ -17375,11 +17433,7 @@ def compare( iff( condition=cols_equal, expr1=pandas_lit(np.nan), - expr2=col( - result_column_mapper.left_quoted_identifiers_map[ - left_identifier - ] - ), + expr2=col(left_mappped_identifier), ) ) @@ -17390,11 +17444,7 @@ def compare( iff( condition=cols_equal, expr1=pandas_lit(np.nan), - expr2=col( - result_column_mapper.right_quoted_identifiers_map[ - right_identifier - ] - ), + expr2=col(right_mapped_identifier), ) ) @@ -17408,8 +17458,13 @@ def compare( column_index_tuples.append((pandas_column_value, "self")) column_index_tuples.append((pandas_column_value, "other")) + column_types.append(left_type) + column_types.append(right_type) + result = SnowflakeQueryCompiler( - filtered_binary_op_result.project_columns(new_pandas_labels, new_values) + filtered_binary_op_result.project_columns( + new_pandas_labels, new_values, column_types + ) ).set_columns( # TODO(SNOW-1510921): fix the levels and inferred_type of the # result's MultiIndex once we can pass the levels correctly through diff --git 
a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 21692a228ec..bbd415536af 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -42,6 +42,7 @@ is_integer_dtype, is_numeric_dtype, is_object_dtype, + is_timedelta64_dtype, pandas_dtype, ) from pandas.core.dtypes.inference import is_hashable @@ -112,21 +113,30 @@ def __new__( from snowflake.snowpark.modin.plugin.extensions.datetime_index import ( DatetimeIndex, ) + from snowflake.snowpark.modin.plugin.extensions.timedelta_index import ( + TimedeltaIndex, + ) if query_compiler: dtype = query_compiler.index_dtypes[0] - if dtype == np.dtype("datetime64[ns]"): + if is_datetime64_any_dtype(dtype): return DatetimeIndex(query_compiler=query_compiler) + if is_timedelta64_dtype(dtype): + return TimedeltaIndex(query_compiler=query_compiler) elif isinstance(data, BasePandasDataset): if data.ndim != 1: raise ValueError("Index data must be 1 - dimensional") dtype = data.dtype - if dtype == np.dtype("datetime64[ns]"): - return DatetimeIndex(data, dtype, copy, name, tupleize_cols) + if is_datetime64_any_dtype(dtype): + return DatetimeIndex(data, dtype=dtype, copy=copy, name=name) + if is_timedelta64_dtype(dtype): + return TimedeltaIndex(data, dtype=dtype, copy=copy, name=name) else: index = native_pd.Index(data, dtype, copy, name, tupleize_cols) if isinstance(index, native_pd.DatetimeIndex): return DatetimeIndex(data) + if isinstance(index, native_pd.TimedeltaIndex): + return TimedeltaIndex(data) return object.__new__(cls) def __init__( @@ -252,9 +262,13 @@ def __getattr__(self, key: str) -> Any: def _binary_ops(self, method: str, other: Any) -> Index: if isinstance(other, Index): other = other.to_series().reset_index(drop=True) - return self.__constructor__( - self.to_series().reset_index(drop=True).__getattr__(method)(other) - ) + series = 
self.to_series().reset_index(drop=True).__getattr__(method)(other) + qc = series._query_compiler + qc = qc.set_index_from_columns(qc.columns, include_index=False) + # Use base constructor to ensure that the correct type is returned. + idx = Index(query_compiler=qc) + idx.name = series.name + return idx def _unary_ops(self, method: str) -> Index: return self.__constructor__( diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py index 3515baaee3a..5d61bc95694 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py @@ -44,9 +44,13 @@ DatetimeIndex, ) from snowflake.snowpark.modin.plugin.extensions.index import Index # noqa: F401 +from snowflake.snowpark.modin.plugin.extensions.timedelta_index import ( # noqa: F401 + TimedeltaIndex, +) register_pd_accessor("Index")(Index) register_pd_accessor("DatetimeIndex")(DatetimeIndex) +register_pd_accessor("TimedeltaIndex")(TimedeltaIndex) @_inherit_docstrings(native_pd.read_csv, apilink="pandas.read_csv") diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py new file mode 100644 index 00000000000..7facf4acefd --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -0,0 +1,388 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +""" +Module houses ``TimedeltaIndex`` class, that is distributed version of +``pandas.TimedeltaIndex``. +""" + +from __future__ import annotations + +import numpy as np +import pandas as native_pd +from pandas._libs import lib +from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable +from pandas.core.dtypes.common import is_timedelta64_dtype + +from snowflake.snowpark.modin.pandas import DataFrame, Series +from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( + SnowflakeQueryCompiler, +) +from snowflake.snowpark.modin.plugin.extensions.index import Index +from snowflake.snowpark.modin.plugin.utils.error_message import ( + timedelta_index_not_implemented, +) + +_CONSTRUCTOR_DEFAULTS = { + "unit": lib.no_default, + "freq": lib.no_default, + "dtype": None, + "copy": False, + "name": None, +} + + +class TimedeltaIndex(Index): + + # Equivalent index type in native pandas + _NATIVE_INDEX_TYPE = native_pd.TimedeltaIndex + + def __new__(cls, *args, **kwargs): + """ + Create new instance of TimedeltaIndex. This overrides behavior of Index.__new__. + Args: + *args: arguments. + **kwargs: keyword arguments. + + Returns: + New instance of TimedeltaIndex. 
+ """ + return object.__new__(cls) + + def __init__( + self, + data: ArrayLike | native_pd.Index | Series | None = None, + unit: str | lib.NoDefault = _CONSTRUCTOR_DEFAULTS["unit"], + freq: Frequency | lib.NoDefault = _CONSTRUCTOR_DEFAULTS["freq"], + dtype: Dtype | None = _CONSTRUCTOR_DEFAULTS["dtype"], + copy: bool = _CONSTRUCTOR_DEFAULTS["copy"], + name: Hashable | None = _CONSTRUCTOR_DEFAULTS["name"], + query_compiler: SnowflakeQueryCompiler = None, + ) -> None: + """ + Immutable Index of timedelta64 data. + + Represented internally as int64, and scalars returned Timedelta objects. + + Parameters + ---------- + data : array-like (1-dimensional), optional + Optional timedelta-like data to construct index with. + unit : {'D', 'h', 'm', 's', 'ms', 'us', 'ns'}, optional + The unit of ``data``. + + .. deprecated:: 2.2.0 + Use ``pd.to_timedelta`` instead. + + freq : str or pandas offset object, optional + One of pandas date offset strings or corresponding objects. The string + ``'infer'`` can be passed in order to set the frequency of the index as + the inferred frequency upon creation. + dtype : numpy.dtype or str, default None + Valid ``numpy`` dtypes are ``timedelta64[ns]``, ``timedelta64[us]``, + ``timedelta64[ms]``, and ``timedelta64[s]``. + copy : bool + Make a copy of input array. + name : object + Name to be stored in the index. + + Examples + -------- + >>> pd.TimedeltaIndex(['0 days', '1 days', '2 days', '3 days', '4 days']) + TimedeltaIndex(['0 days', '1 days', '2 days', '3 days', '4 days'], dtype='timedelta64[ns]') + + We can also let pandas infer the frequency when possible. + + >>> pd.TimedeltaIndex(np.arange(5) * 24 * 3600 * 1e9, freq='infer') + TimedeltaIndex(['0 days', '1 days', '2 days', '3 days', '4 days'], dtype='timedelta64[ns]') + """ + if query_compiler: + # Raise error if underlying type is not a Timedelta type. 
+ current_dtype = query_compiler.index_dtypes[0] + if not is_timedelta64_dtype(current_dtype): + raise ValueError( + f"TimedeltaIndex can only be created from a query compiler with TimedeltaType, found {current_dtype}" + ) + kwargs = { + "unit": unit, + "freq": freq, + "dtype": dtype, + "copy": copy, + "name": name, + } + self._init_index(data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs) + + @timedelta_index_not_implemented() + @property + def days(self) -> Index: + """ + Number of days for each element. + + Returns + ------- + An Index with the days component of the timedelta. + + Examples + -------- + >>> idx = pd.to_timedelta(["0 days", "10 days", "20 days"]) # doctest: +SKIP + >>> idx # doctest: +SKIP + TimedeltaIndex(['0 days', '10 days', '20 days'], + dtype='timedelta64[ns]', freq=None) + >>> idx.days # doctest: +SKIP + Index([0, 10, 20], dtype='int64') + """ + + @timedelta_index_not_implemented() + @property + def seconds(self) -> Index: + """ + Number of seconds (>= 0 and less than 1 day) for each element. + + Returns + ------- + An Index with seconds component of the timedelta. + + Examples + -------- + >>> idx = pd.to_timedelta([1, 2, 3], unit='s') # doctest: +SKIP + >>> idx # doctest: +SKIP + TimedeltaIndex(['0 days 00:00:01', '0 days 00:00:02', '0 days 00:00:03'], + dtype='timedelta64[ns]', freq=None) + >>> idx.seconds # doctest: +SKIP + Index([1, 2, 3], dtype='int32') + """ + + @timedelta_index_not_implemented() + @property + def microseconds(self) -> Index: + """ + Number of microseconds (>= 0 and less than 1 second) for each element. + + Returns + ------- + An Index with microseconds component of the timedelta. 
+ + Examples + -------- + >>> idx = pd.to_timedelta([1, 2, 3], unit='us') # doctest: +SKIP + >>> idx # doctest: +SKIP + TimedeltaIndex(['0 days 00:00:00.000001', '0 days 00:00:00.000002', + '0 days 00:00:00.000003'], + dtype='timedelta64[ns]', freq=None) + >>> idx.microseconds # doctest: +SKIP + Index([1, 2, 3], dtype='int32') + """ + + @timedelta_index_not_implemented() + @property + def nanoseconds(self) -> Index: + """ + Number of nanoseconds (>= 0 and less than 1 microsecond) for each element. + + Returns + ------- + An Index with nanoseconds component of the timedelta. + + Examples + -------- + >>> idx = pd.to_timedelta([1, 2, 3], unit='ns') # doctest: +SKIP + >>> idx # doctest: +SKIP + TimedeltaIndex(['0 days 00:00:00.000000001', '0 days 00:00:00.000000002', + '0 days 00:00:00.000000003'], + dtype='timedelta64[ns]', freq=None) + >>> idx.nanoseconds # doctest: +SKIP + Index([1, 2, 3], dtype='int32') + """ + + @timedelta_index_not_implemented() + @property + def components(self) -> DataFrame: + """ + Return a DataFrame of the individual resolution components of the Timedeltas. + + The components (days, hours, minutes, seconds, milliseconds, microseconds, + nanoseconds) are returned as columns in a DataFrame. + + Returns + ------- + A DataFrame + + Examples + -------- + >>> idx = pd.to_timedelta(['1 day 3 min 2 us 42 ns']) # doctest: +SKIP + >>> idx # doctest: +SKIP + TimedeltaIndex(['1 days 00:03:00.000002042'], + dtype='timedelta64[ns]', freq=None) + >>> idx.components # doctest: +SKIP + days hours minutes seconds milliseconds microseconds nanoseconds + 0 1 0 3 0 0 2 42 + """ + + @timedelta_index_not_implemented() + @property + def inferred_freq(self) -> str | None: + """ + Tries to return a string representing a frequency generated by infer_freq. + + Returns None if it can't autodetect the frequency. 
+ + Examples + -------- + >>> idx = pd.to_timedelta(["0 days", "10 days", "20 days"]) # doctest: +SKIP + >>> idx # doctest: +SKIP + TimedeltaIndex(['0 days', '10 days', '20 days'], + dtype='timedelta64[ns]', freq=None) + >>> idx.inferred_freq # doctest: +SKIP + '10D' + """ + + @timedelta_index_not_implemented() + def round(self, freq: Frequency) -> TimedeltaIndex: + """ + Perform round operation on the data to the specified `freq`. + + Parameters + ---------- + freq : str or Offset + The frequency level to round the index to. Must be a fixed + frequency like 'S' (second) not 'ME' (month end). See + frequency aliases for a list of possible `freq` values. + + Returns + ------- + TimedeltaIndex with round values. + + Raises + ------ + ValueError if the `freq` cannot be converted. + """ + + @timedelta_index_not_implemented() + def floor(self, freq: Frequency) -> TimedeltaIndex: + """ + Perform floor operation on the data to the specified `freq`. + + Parameters + ---------- + freq : str or Offset + The frequency level to floor the index to. Must be a fixed + frequency like 'S' (second) not 'ME' (month end). See + frequency aliases for a list of possible `freq` values. + + Returns + ------- + TimedeltaIndex with floor values. + + Raises + ------ + ValueError if the `freq` cannot be converted. + """ + + @timedelta_index_not_implemented() + def ceil(self, freq: Frequency) -> TimedeltaIndex: + """ + Perform ceil operation on the data to the specified `freq`. + + Parameters + ---------- + freq : str or Offset + The frequency level to ceil the index to. Must be a fixed + frequency like 'S' (second) not 'ME' (month end). See + frequency aliases for a list of possible `freq` values. + + Returns + ------- + TimedeltaIndex with ceil values. + + Raises + ------ + ValueError if the `freq` cannot be converted. + """ + + @timedelta_index_not_implemented() + def to_pytimedelta(self) -> np.ndarray: + """ + Return an ndarray of datetime.timedelta objects. 
+ + Returns + ------- + numpy.ndarray + + Examples + -------- + >>> idx = pd.to_timedelta([1, 2, 3], unit='D') # doctest: +SKIP + >>> idx # doctest: +SKIP + TimedeltaIndex(['1 days', '2 days', '3 days'], + dtype='timedelta64[ns]', freq=None) + >>> idx.to_pytimedelta() # doctest: +SKIP + array([datetime.timedelta(days=1), datetime.timedelta(days=2), + datetime.timedelta(days=3)], dtype=object) + """ + + @timedelta_index_not_implemented() + def mean( + self, *, skipna: bool = True, axis: AxisInt | None = 0 + ) -> native_pd.Timestamp: + """ + Return the mean value of the Array. + + Parameters + ---------- + skipna : bool, default True + Whether to ignore any NaT elements. + axis : int, optional, default 0 + + Returns + ------- + scalar Timestamp + + See Also + -------- + numpy.ndarray.mean : Returns the average of array elements along a given axis. + Series.mean : Return the mean value in a Series. + + Notes + ----- + mean is only defined for Datetime and Timedelta dtypes, not for Period. + """ + + @timedelta_index_not_implemented() + def as_unit(self, unit: str) -> TimedeltaIndex: + """ + Convert to a dtype with the given unit resolution. 
+ + Parameters + ---------- + unit : {'s', 'ms', 'us', 'ns'} + + Returns + ------- + TimedeltaIndex + + Examples + -------- + >>> idx = pd.to_timedelta(['1 day 3 min 2 us 42 ns']) # doctest: +SKIP + >>> idx # doctest: +SKIP + TimedeltaIndex(['1 days 00:03:00.000002042'], + dtype='timedelta64[ns]', freq=None) + >>> idx.as_unit('s') # doctest: +SKIP + TimedeltaIndex(['1 days 00:03:00'], dtype='timedelta64[s]', freq=None) + """ diff --git a/src/snowflake/snowpark/modin/plugin/utils/error_message.py b/src/snowflake/snowpark/modin/plugin/utils/error_message.py index 1e832450579..7fc86152c63 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/error_message.py +++ b/src/snowflake/snowpark/modin/plugin/utils/error_message.py @@ -143,6 +143,10 @@ def raise_not_implemented_method_error( decorating_functions=False, attribute_prefix="DatetimeIndex" ) +timedelta_index_not_implemented = _make_not_implemented_decorator( + decorating_functions=False, attribute_prefix="TimedeltaIndex" +) + pandas_module_level_function_not_implemented = _make_not_implemented_decorator( decorating_functions=True, attribute_prefix="pd" ) diff --git a/tests/integ/modin/binary/test_timedelta.py b/tests/integ/modin/binary/test_timedelta.py index 632243664d7..d9fa20b1a40 100644 --- a/tests/integ/modin/binary/test_timedelta.py +++ b/tests/integ/modin/binary/test_timedelta.py @@ -11,8 +11,10 @@ import pytest import snowflake.snowpark.modin.plugin # noqa: F401 +from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.modin.sql_counter import sql_count_checker from tests.integ.modin.utils import ( + assert_series_equal, assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_test_dfs, create_test_series, @@ -109,12 +111,37 @@ def timestamp_scalar(request): return request.param +@pytest.fixture( + params=[ + pd.Timedelta("10 days 23:59:59.123456789"), + datetime.timedelta(microseconds=1), + datetime.timedelta(microseconds=2), + pd.Timedelta(nanoseconds=1), + 
pd.Timedelta(nanoseconds=2), + pd.Timedelta(nanoseconds=3), + pd.Timedelta(days=1), + pd.Timedelta(days=1, hours=1), + pd.Timedelta(days=10), + ] +) +def timedelta_scalar_positive(request): + return request.param + + @pytest.fixture( params=[ pd.Timedelta("10 days 23:59:59.123456789"), pd.Timedelta("-10 days 23:59:59.123456789"), datetime.timedelta(days=-10, hours=23), datetime.timedelta(microseconds=1), + datetime.timedelta(microseconds=2), + pd.Timedelta(nanoseconds=1), + pd.Timedelta(nanoseconds=2), + pd.Timedelta(nanoseconds=3), + pd.Timedelta(days=1), + pd.Timedelta(days=1, hours=1), + pd.Timedelta(days=10), + pd.Timedelta(days=-1), ] ) def timedelta_scalar(request): @@ -134,6 +161,36 @@ def timedelta_dataframes_1() -> tuple[pd.DataFrame, native_pd.DataFrame]: ) +@pytest.fixture +def timedelta_dataframes_postive_no_nulls_1_2x2() -> tuple[ + pd.DataFrame, native_pd.DataFrame +]: + return create_test_dfs( + [ + [pd.Timedelta(days=1), pd.Timedelta(days=4)], + [ + pd.Timedelta(days=2), + pd.Timedelta(days=3), + ], + ] + ) + + +@pytest.fixture +def timedelta_dataframes_with_negatives_no_nulls_1_2x2() -> tuple[ + pd.DataFrame, native_pd.DataFrame +]: + return create_test_dfs( + [ + [pd.Timedelta(days=1), pd.Timedelta(days=4)], + [ + pd.Timedelta(days=2), + pd.Timedelta(days=-3), + ], + ] + ) + + @pytest.fixture def timedelta_series_1() -> tuple[pd.Series, native_pd.Series]: return create_test_series( @@ -148,10 +205,332 @@ def timedelta_series_1() -> tuple[pd.Series, native_pd.Series]: ) +@pytest.fixture +def timedelta_series_positive_no_nulls_1_length_6() -> tuple[ + pd.Series, native_pd.Series +]: + return create_test_series( + [ + pd.Timedelta(days=1), + pd.Timedelta(days=2), + pd.Timedelta(days=3), + pd.Timedelta(days=4), + pd.Timedelta(days=5), + pd.Timedelta(days=6), + ] + ) + + +@pytest.fixture +def timedelta_series_no_nulls_2_length_6() -> tuple[pd.Series, native_pd.Series]: + return create_test_series( + [ + pd.Timedelta(microseconds=7), + 
pd.Timedelta(hours=6, minutes=5), + pd.Timedelta(hours=4, minutes=3), + pd.Timedelta(hours=2, minutes=1), + pd.Timedelta(hours=8, minutes=9), + pd.Timedelta(hours=9, minutes=8), + ] + ) + + +@pytest.fixture +def timedelta_series_no_nulls_3_length_2() -> tuple[pd.Series, native_pd.Series]: + return create_test_series( + [ + pd.Timedelta(microseconds=7), + pd.Timedelta(hours=6, minutes=5), + ] + ) + + +@pytest.fixture( + params=[ + "sub", + "rsub", + "add", + "radd", + "div", + "rdiv", + "truediv", + "rtruediv", + "floordiv", + "rfloordiv", + "mod", + "rmod", + "eq", + "ne", + "gt", + "lt", + "ge", + "le", + ] +) +def op_between_timedeltas(request) -> list[str]: + """Valid operations between timedeltas.""" + return request.param + + +class TestInvalid: + """ + Test invalid binary operations, e.g. subtracting a timestamp from a timedelta. + + For simplicity, check these cases for operations between dataframes and + scalars only. + """ + + @sql_count_checker(query_count=0) + def test_timedelta_scalar_minus_timestamp_dataframe(self): + eval_snowpark_pandas_result( + *create_test_dfs([datetime.datetime(year=2024, month=8, day=21)]), + lambda df: pd.Timedelta(1) - df, + expect_exception=True, + expect_exception_type=TypeError, + expect_exception_match=re.escape( + "bad operand type for unary -: 'DatetimeArray" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.parametrize( + "operation,error_symbol", + [("__or__", "|"), ("__ror__", "|"), ("__and__", "&"), ("__rand__", "&")], + ) + def test_timedelta_dataframe_bitwise_operation_with_timedelta_scalar( + self, operation, timedelta_dataframes_1, error_symbol + ): + eval_snowpark_pandas_result( + *timedelta_dataframes_1, + lambda df: getattr(df, operation)(pd.Timedelta(2)), + expect_exception=True, + # pandas exception depends on the input types and is something like + # "unsupported operand type(s) for &: 'Timedelta' and 'TimedeltaArray'", + # but Snowpark pandas always gives the same exception. 
+ assert_exception_equal=False, + expect_exception_type=TypeError, + expect_exception_match=re.escape( + f"unsupported operand type for {error_symbol}: Timedelta" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.parametrize("operation", ["pow", "rpow"]) + def test_timedelta_dataframe_exponentiation_with_timedelta_scalar( + self, operation, timedelta_dataframes_1 + ): + eval_snowpark_pandas_result( + *timedelta_dataframes_1, + lambda df: getattr(df, operation)(pd.Timedelta(2)), + expect_exception=True, + # pandas exception depends on the input types and is something + # like "cannot perform __rpow__ with this index type: + # TimedeltaArray", but Snowpark pandas always gives the same + # exception. + assert_exception_equal=False, + expect_exception_type=TypeError, + expect_exception_match=re.escape( + "unsupported operand type for **: Timedelta" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.parametrize("operation", ["mul", "rmul"]) + def test_timedelta_dataframe_multiplied_by_timedelta_scalar_invalid( + self, operation, timedelta_dataframes_1 + ): + eval_snowpark_pandas_result( + *timedelta_dataframes_1, + lambda df: getattr(df, operation)(pd.Timedelta(2)), + expect_exception=True, + expect_exception_type=np.core._exceptions._UFuncBinaryResolutionError, + expect_exception_match=re.escape( + "ufunc 'multiply' cannot use operands with types dtype(' native_pd.DataFrame: class TestDefaultParameters: - @sql_count_checker(query_count=QUERY_COUNT, join_count=JOIN_COUNT) + @sql_count_checker( + query_count=QUERY_COUNT_MULTI_LEVEL_INDEX, + join_count=JOIN_COUNT_MULTI_LEVEL_INDEX, + ) def test_no_diff(self, base_df): other_df = base_df.copy() eval_snowpark_pandas_result( @@ -72,6 +85,56 @@ def test_no_diff(self, base_df): # In snowpark pandas, the column index of the empty resulting frame # has the correct values and names, but the incorrect inferred_type # for some of its levels. Ignore that bug for now. + # TODO(SNOW-1510921): fix the bug. 
+ check_index_type=False, + check_column_type=False, + ) + + @sql_count_checker( + # no joins because we can skip the joins when comparing df to df.copy() + query_count=QUERY_COUNT_SINGLE_LEVEL_INDEX, + join_count=0, + ) + def test_no_diff_timedelta(self): + eval_snowpark_pandas_result( + *create_test_dfs([pd.Timedelta(1)]), + lambda df: df.compare(df.copy()), + check_index_type=False, + check_column_type=False, + ) + + @sql_count_checker( + query_count=QUERY_COUNT_SINGLE_LEVEL_INDEX, + join_count=JOIN_COUNT_SINGLE_LEVEL_INDEX, + ) + def test_one_diff_timedelta(self): + base_snow_df, base_pandas_df = create_test_dfs( + [[pd.Timedelta(1), pd.Timedelta(2)]] + ) + other_snow_df, other_pandas_df = create_test_dfs( + [[pd.Timedelta(1), pd.Timedelta(3)]] + ) + eval_snowpark_pandas_result( + (base_snow_df, other_snow_df), + (base_pandas_df, other_pandas_df), + lambda t: t[0].compare(t[1]), + check_index_type=False, + check_column_type=False, + ) + + @sql_count_checker( + query_count=QUERY_COUNT_SINGLE_LEVEL_INDEX, + join_count=JOIN_COUNT_SINGLE_LEVEL_INDEX, + ) + def test_timedelta_compared_with_int(self): + base_snow_df, base_pandas_df = create_test_dfs([[pd.Timedelta(1), 2]]) + other_snow_df, other_pandas_df = create_test_dfs( + [[pd.Timedelta(1), pd.Timedelta(2)]] + ) + eval_snowpark_pandas_result( + (base_snow_df, other_snow_df), + (base_pandas_df, other_pandas_df), + lambda t: t[0].compare(t[1]), check_index_type=False, check_column_type=False, ) @@ -86,7 +149,10 @@ def test_no_diff(self, base_df): ((3, 4), [201]), ], ) - @sql_count_checker(query_count=QUERY_COUNT, join_count=JOIN_COUNT) + @sql_count_checker( + query_count=QUERY_COUNT_MULTI_LEVEL_INDEX, + join_count=JOIN_COUNT_MULTI_LEVEL_INDEX, + ) def test_single_value_diff(self, base_df, position, new_value): # check that we are changing a value, so the test case is meaningful. 
assert not ( @@ -122,7 +188,10 @@ def test_default_index_on_both_axes(self, base_df): ), ) - @sql_count_checker(query_count=QUERY_COUNT, join_count=JOIN_COUNT) + @sql_count_checker( + query_count=QUERY_COUNT_MULTI_LEVEL_INDEX, + join_count=JOIN_COUNT_MULTI_LEVEL_INDEX, + ) def test_different_value_in_every_column_and_row(self, base_df): other_df = base_df.copy() other_df.iloc[0, 0] = "c" diff --git a/tests/integ/modin/frame/test_equals.py b/tests/integ/modin/frame/test_equals.py index e57c4180231..95b6b8ffd6f 100644 --- a/tests/integ/modin/frame/test_equals.py +++ b/tests/integ/modin/frame/test_equals.py @@ -15,6 +15,12 @@ "lhs, rhs, expected", [ ([1, 2, 3], [1, 2, 3], True), + pytest.param( + [pd.Timedelta(1), pd.Timedelta(2), pd.Timedelta(3)], + [pd.Timedelta(1), pd.Timedelta(2), pd.Timedelta(3)], + True, + id="timedelta", + ), ([1, 2, 3], [1, 2, 4], False), # different values ([1, 2, None], [1, 2, None], True), # nulls are considered equal ([1, 2, 3], [1.0, 2.0, 3.0], False), # float and integer types are not equal @@ -58,6 +64,8 @@ def test_equals_column_labels(lhs, rhs, expected): (np.float64, np.float32, True), (np.int16, "object", False), (np.int16, np.float16, False), + ("timedelta64[ns]", int, False), + ("timedelta64[ns]", float, False), ], ) @sql_count_checker(query_count=2, join_count=2) diff --git a/tests/integ/modin/index/conftest.py b/tests/integ/modin/index/conftest.py index bc26872380e..3c6362dd83c 100644 --- a/tests/integ/modin/index/conftest.py +++ b/tests/integ/modin/index/conftest.py @@ -16,6 +16,10 @@ data={"col1": [1, 2, 3], "col2": [3, 4, 5]}, index=native_pd.DatetimeIndex(["2024-01-01", "2024-02-01", "2024-03-01"]), ), + native_pd.DataFrame( + data={"col1": [1, 2, 3], "col2": [3, 4, 5]}, + index=native_pd.TimedeltaIndex(["0 days", "1 days", "3 days"]), + ), ] NATIVE_INDEX_TEST_DATA = [ @@ -36,6 +40,8 @@ tz="America/Los_Angeles", ), native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]), + native_pd.TimedeltaIndex(["0 days", "1 
days", "3 days"]), + native_pd.TimedeltaIndex([100, 200, 300]), ] NATIVE_INDEX_UNIQUE_TEST_DATA = [ diff --git a/tests/integ/modin/index/test_timedelta_index_methods.py b/tests/integ/modin/index/test_timedelta_index_methods.py new file mode 100644 index 00000000000..1baafed24d2 --- /dev/null +++ b/tests/integ/modin/index/test_timedelta_index_methods.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import modin.pandas as pd +import pandas as native_pd +import pytest + +import snowflake.snowpark.modin.plugin # noqa: F401 +from tests.integ.modin.sql_counter import sql_count_checker + + +@sql_count_checker(query_count=3) +def test_timedelta_index_construction(): + # create from native pandas timedelta index. + index = native_pd.TimedeltaIndex(["1 days", "2 days", "3 days"]) + snow_index = pd.Index(index) + assert isinstance(snow_index, pd.TimedeltaIndex) + + # create from query compiler with timedelta type. + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=index) + snow_index = df.index + assert isinstance(snow_index, pd.TimedeltaIndex) + + # create from snowpark pandas timedelta index. + snow_index = pd.Index(pd.TimedeltaIndex([123])) + assert isinstance(snow_index, pd.TimedeltaIndex) + + # create by subtracting datetime index from another. 
+ date_range1 = pd.date_range("2000-01-01", periods=10, freq="h") + date_range2 = pd.date_range("2001-05-01", periods=10, freq="h") + snow_index = date_range2 - date_range1 + assert isinstance(snow_index, pd.TimedeltaIndex) + + +@sql_count_checker(query_count=0) +@pytest.mark.parametrize( + "kwargs", + [ + {"unit": "ns"}, + {"freq": "M"}, + {"dtype": "int"}, + {"copy": True}, + {"name": "abc"}, + ], +) +def test_non_default_args(kwargs): + idx = pd.TimedeltaIndex(["1 days"]) + + name = list(kwargs.keys())[0] + value = list(kwargs.values())[0] + msg = f"Non-default argument '{name}={value}' when constructing Index with query compiler" + with pytest.raises(AssertionError, match=msg): + pd.TimedeltaIndex(query_compiler=idx._query_compiler, **kwargs) + + +@pytest.mark.parametrize( + "property", ["days", "seconds", "microseconds", "nanoseconds", "inferred_freq"] +) +@sql_count_checker(query_count=0) +def test_property_not_implemented(property): + snow_index = pd.TimedeltaIndex(["1 days", "2 days"]) + msg = f"Snowpark pandas does not yet support the property TimedeltaIndex.{property}" + with pytest.raises(NotImplementedError, match=msg): + getattr(snow_index, property) diff --git a/tests/integ/modin/series/test_astype.py b/tests/integ/modin/series/test_astype.py index ff65c677c05..9c00e9a675d 100644 --- a/tests/integ/modin/series/test_astype.py +++ b/tests/integ/modin/series/test_astype.py @@ -407,8 +407,8 @@ def test_python_datetime_astype_DatetimeTZDtype(seed): @sql_count_checker(query_count=1) @pytest.mark.parametrize( "data", - [[12345678, 9], [12345678, 2.6], [True, False], [1, "2"], ["1", "2"]], - ids=["int", "float", "boolean", "object", "string"], + [[12345678, 9], [12345678, 2.6], [True, False], [1, "2"]], + ids=["int", "float", "boolean", "object"], ) def test_astype_to_timedelta(data): native_series = native_pd.Series(data) @@ -419,24 +419,31 @@ def test_astype_to_timedelta(data): @sql_count_checker(query_count=2) -def 
test_astype_datetime_to_timedelta_negative(): - native_series = native_pd.Series( +def test_astype_to_timedelta_negative(): + native_datetime_series = native_pd.Series( data=[pd.to_datetime("2000-01-01"), pd.to_datetime("2001-01-01")] ) - snow_series = pd.Series(native_series) + snow_datetime_series = pd.Series(native_datetime_series) with SqlCounter(query_count=0): with pytest.raises( TypeError, match=re.escape("Cannot cast DatetimeArray to dtype timedelta64[ns]"), ): - native_series.astype("timedelta64[ns]") + native_datetime_series.astype("timedelta64[ns]") with pytest.raises( TypeError, match=re.escape( "dtype datetime64[ns] cannot be converted to timedelta64[ns]" ), ): - snow_series.astype("timedelta64[ns]") + snow_datetime_series.astype("timedelta64[ns]") + with SqlCounter(query_count=0): + snow_string_series = pd.Series(data=["2 days, 3 minutes"]) + with pytest.raises( + NotImplementedError, + match=re.escape("dtype object cannot be converted to timedelta64[ns]"), + ): + snow_string_series.astype("timedelta64[ns]") @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/series/test_equals.py b/tests/integ/modin/series/test_equals.py index b0f7af34b19..912726ff2f9 100644 --- a/tests/integ/modin/series/test_equals.py +++ b/tests/integ/modin/series/test_equals.py @@ -15,6 +15,12 @@ "lhs, rhs, expected", [ ([1, 2, 3], [1, 2, 3], True), + pytest.param( + [pd.Timedelta(1), pd.Timedelta(2), pd.Timedelta(3)], + [pd.Timedelta(1), pd.Timedelta(2), pd.Timedelta(3)], + True, + id="timedelta", + ), ([1, 2, None], [1, 2, None], True), # nulls are considered equal ([1, 2, 3], [1.0, 2.0, 3.0], False), # float and integer types are not equal ([1, 2, 3], ["1", "2", "3"], False), # integer and string types are not equal @@ -37,6 +43,8 @@ def test_equals_series(lhs, rhs, expected): (np.float64, np.float32, True), (np.int16, "object", False), (np.int16, np.float16, False), + ("timedelta64[ns]", int, False), + ("timedelta64[ns]", float, False), ], ) 
@sql_count_checker(query_count=2, join_count=2) diff --git a/tests/integ/modin/test_timedelta_ops.py b/tests/integ/modin/test_timedelta_ops.py index 2d38c1e372f..c60b91b3273 100644 --- a/tests/integ/modin/test_timedelta_ops.py +++ b/tests/integ/modin/test_timedelta_ops.py @@ -26,8 +26,8 @@ } -@sql_count_checker(query_count=0) -def test_td_case1_negative(): +@sql_count_checker(query_count=1) +def test_insert_datetime_difference_in_days(): data = TIME_DATA1 snow_df = pd.DataFrame(data) native_df = native_pd.DataFrame(data) @@ -41,80 +41,20 @@ def test_td_case1_negative(): ) / np.timedelta64(1, "D") ).round() - # TODO SNOW-1635620: remove Exception raised when TimeDelta is implemented - with pytest.raises(NotImplementedError): - snow_df["month_lag"] = ( - ( - pd.to_datetime(snow_df["CREATED_AT"], format="%Y-%m-%d %H:%M:%S") - - pd.to_datetime( - snow_df["REPORTING_DATE"], format="%Y-%m-%d", errors="coerce" - ) - ) - / np.timedelta64(1, "D") - ).round() - assert_series_equal(snow_df["month_lag"], native_df["open_lag"]) - - -@sql_count_checker(query_count=0) -def test_td_case2_negative(): - data = TIME_DATA1 - snow_df = pd.DataFrame(data) - native_df = native_pd.DataFrame(data) - native_df["open_lag"] = ( - ( - native_pd.to_datetime(native_df["CREATED_AT"], format="%Y-%m-%d %H:%M:%S") - - native_pd.to_datetime( - native_df["OPEN_DATE"], format="%Y-%m-%d", errors="coerce" - ) - ) - / np.timedelta64(1, "D") - ).round() - # TODO SNOW-1635620: remove Exception raised when TimeDelta is implemented - with pytest.raises(NotImplementedError): - snow_df["open_lag"] = ( - ( - pd.to_datetime(snow_df["CREATED_AT"], format="%Y-%m-%d %H:%M:%S") - - pd.to_datetime( - snow_df["OPEN_DATE"], format="%Y-%m-%d", errors="coerce" - ) - ) - / np.timedelta64(1, "D") - ).round() - assert_series_equal(snow_df["open_lag"], native_df["open_lag"]) - - -@sql_count_checker(query_count=0) -def test_td_case3_negative(): - data = TIME_DATA1 - snow_df = pd.DataFrame(data) - native_df = 
native_pd.DataFrame(data) - - native_df["close_lag"] = ( + snow_df["month_lag"] = ( ( - native_pd.to_datetime(native_df["CREATED_AT"], format="%Y-%m-%d %H:%M:%S") - - native_pd.to_datetime( - native_df["CLOSED_DATE"], format="%Y-%m-%d", errors="coerce" + pd.to_datetime(snow_df["CREATED_AT"], format="%Y-%m-%d %H:%M:%S") + - pd.to_datetime( + snow_df["REPORTING_DATE"], format="%Y-%m-%d", errors="coerce" ) ) / np.timedelta64(1, "D") ).round() - # TODO SNOW-1635620: remove Exception raised when TimeDelta is implemented - with pytest.raises(NotImplementedError): - snow_df["close_lag"] = ( - ( - pd.to_datetime(snow_df["CREATED_AT"], format="%Y-%m-%d %H:%M:%S") - - pd.to_datetime( - snow_df["CLOSED_DATE"], format="%Y-%m-%d", errors="coerce" - ) - ) - / np.timedelta64(1, "D") - ).round() - - assert_series_equal(snow_df["close_lag"], native_df["close_lag"]) + assert_series_equal(snow_df["month_lag"], native_df["month_lag"]) @sql_count_checker(query_count=1) -def test_td_case4(): +def test_insert_datetime_difference(): data = { "bl_start_ts": [Timestamp("2017-03-01T12")], "green_light_ts": [Timestamp("2017-01-07T12")], @@ -131,7 +71,7 @@ def test_td_case4(): @sql_count_checker(query_count=0) -def test_td_case5_negative(): +def test_diff_timestamp_column_to_get_timedelta_negative(): data = { "Country": ["A", "B", "C", "D", "E"], "Agreement Signing Date": [ @@ -144,7 +84,7 @@ def test_td_case5_negative(): } snow_df = pd.DataFrame(data) native_df = native_pd.DataFrame(data) - # TODO SNOW-1635620: remove Exception raised when TimeDelta is implemented + # TODO SNOW-1641729: remove Exception raised when TimeDelta is implemented with pytest.raises(SnowparkSQLException): eval_snowpark_pandas_result( snow_df, diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index a322c7d34b8..98b2bdbfeef 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -36,6 +36,7 @@ array_intersection, 
array_position, array_prepend, + array_remove, array_size, array_slice, array_to_string, @@ -2823,6 +2824,34 @@ def test_array_append(session): ) +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="array_remove is not yet supported in local testing mode.", +) +def test_array_remove(session): + Utils.check_answer( + [ + Row("[\n 2,\n 3\n]"), + Row("[\n 6,\n 7\n]"), + ], + TestData.array1(session).select( + array_remove(array_remove(col("arr1"), lit(1)), lit(8)) + ), + sort=False, + ) + + Utils.check_answer( + [ + Row("[\n 2,\n 3\n]"), + Row("[\n 6,\n 7\n]"), + ], + TestData.array1(session).select( + array_remove(array_remove(col("arr1"), 1), lit(8)) + ), + sort=False, + ) + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="array_cat is not yet supported in local testing mode.", diff --git a/tests/integ/scala/test_snowflake_plan_suite.py b/tests/integ/scala/test_snowflake_plan_suite.py index 4b2f538ea40..3d9f2e22b24 100644 --- a/tests/integ/scala/test_snowflake_plan_suite.py +++ b/tests/integ/scala/test_snowflake_plan_suite.py @@ -175,9 +175,8 @@ def check_plan_queries( # the cte optimization is not kicking in when sql simplifier disabled, because # the cte_optimization_enabled is set to False when constructing the plan for df2, # and place_holder is not propogated. - # TODO (SNOW-1541096): revisit this test once the cte optimization is switched to the - # new compilation infra. 
- cte_applied=session.sql_simplifier_enabled, + cte_applied=session.sql_simplifier_enabled + or session._query_compilation_stage_enabled, exec_queries=df2._plan.execution_queries, ) diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 6aa115afcc2..87a91deab0e 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -10,6 +10,7 @@ from snowflake.connector.options import installed_pandas from snowflake.snowpark import Window from snowflake.snowpark._internal.analyzer import analyzer +from snowflake.snowpark._internal.analyzer.snowflake_plan import PlanQueryType from snowflake.snowpark._internal.utils import ( TEMP_OBJECT_NAME_PREFIX, TempObjectType, @@ -35,6 +36,16 @@ ) ] +binary_operations = [ + lambda x, y: x.union_all(y), + lambda x, y: x.select("a").union(y.select("a")), + lambda x, y: x.except_(y), + lambda x, y: x.select("a").intersect(y.select("a")), + lambda x, y: x.join(y.select("a", "b"), rsuffix="_y"), + lambda x, y: x.select("a").join(y, how="outer", rsuffix="_y"), + lambda x, y: x.join(y.select("a"), how="left", rsuffix="_y"), +] + WITH = "WITH" @@ -104,18 +115,7 @@ def test_unary(session, action): check_result(session, df_action.union_all(df_action), expect_cte_optimized=True) -@pytest.mark.parametrize( - "action", - [ - lambda x, y: x.union_all(y), - lambda x, y: x.select("a").union(y.select("a")), - lambda x, y: x.except_(y), - lambda x, y: x.select("a").intersect(y.select("a")), - lambda x, y: x.join(y.select("a", "b"), rsuffix="_y"), - lambda x, y: x.select("a").join(y, how="outer", rsuffix="_y"), - lambda x, y: x.join(y.select("a"), how="left", rsuffix="_y"), - ], -) +@pytest.mark.parametrize("action", binary_operations) def test_binary(session, action): df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) check_result(session, action(df, df), expect_cte_optimized=True) @@ -138,6 +138,67 @@ def test_binary(session, action): assert len(plan_queries["post_actions"]) == 1 +@pytest.mark.parametrize("action", 
binary_operations) +def test_variable_binding_binary(session, action): + df1 = session.sql( + "select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "a", 2, "b"] + ) + df2 = session.sql( + "select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "c", 3, "d"] + ) + df3 = session.sql( + "select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "a", 2, "b"] + ) + + check_result(session, action(df1, df3), expect_cte_optimized=True) + check_result(session, action(df1, df2), expect_cte_optimized=False) + + +def test_variable_binding_multiple(session): + if not session._query_compilation_stage_enabled: + pytest.skip( + "CTE query generation without the new query generation doesn't work correctly" + ) + + df1 = session.sql( + "select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "a", 2, "b"] + ) + df2 = session.sql( + "select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "c", 3, "d"] + ) + + df_res = df1.union(df1).union(df2) + check_result(session, df_res, expect_cte_optimized=True) + plan_queries = df_res._plan.execution_queries + + assert plan_queries[PlanQueryType.QUERIES][-1].params == [ + 1, + "a", + 2, + "b", + 1, + "c", + 3, + "d", + ] + + df_res = df2.union(df1).union(df2).union(df1) + check_result(session, df_res, expect_cte_optimized=True) + plan_queries = df_res._plan.execution_queries + + assert plan_queries[PlanQueryType.QUERIES][-1].params == [ + 1, + "a", + 2, + "b", + 1, + "c", + 3, + "d", + ] + assert plan_queries[PlanQueryType.QUERIES][-1].sql.count(WITH) == 1 + + @pytest.mark.parametrize( "action", [ diff --git a/tests/integ/test_df_to_snowpark_pandas.py b/tests/integ/test_df_to_snowpark_pandas.py index 05d51b1b38a..ede9b10e85c 100644 --- a/tests/integ/test_df_to_snowpark_pandas.py +++ b/tests/integ/test_df_to_snowpark_pandas.py @@ -5,6 +5,8 @@ # Tests behavior of to_snowpark_pandas() without explicitly initializing Snowpark pandas. 
+import sys + import pytest from snowflake.snowpark._internal.utils import TempObjectType @@ -47,9 +49,14 @@ def test_to_snowpark_pandas_no_modin(session, tmp_table_basic): # TODO: SNOW-1552497: after upgrading to modin 0.30.1, Snowpark pandas will support # all pandas 2.2.x, and this function call will raise a ModuleNotFoundError since # modin is not installed. + match = ( + "Snowpark pandas does not support Python 3.8. Please update to Python 3.9 or later" + if sys.version_info.major == 3 and sys.version_info.minor == 8 + else "does not match the supported pandas version in Snowpark pandas" + ) with pytest.raises( RuntimeError, - match="does not match the supported pandas version in Snowpark pandas", + match=match, ): snowpark_df.to_snowpark_pandas() else: diff --git a/tests/notebooks/modin/MIMICHealthcareDemo.ipynb b/tests/notebooks/modin/MIMICHealthcareDemo.ipynb index 40e82c78d6b..3f1849e52cd 100644 --- a/tests/notebooks/modin/MIMICHealthcareDemo.ipynb +++ b/tests/notebooks/modin/MIMICHealthcareDemo.ipynb @@ -392,10 +392,17 @@ "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:snowflake.snowpark.modin.plugin.utils.warning_message:Snowpark pandas support for Timedelta is not currently available.\n" + ] + } + ], "source": [ - "# TODO SNOW-1635620: uncomment when TimeDelta is implemented\n", - "# df[\"length_of_stay\"] = (df[\"outtime\"]-df[\"intime\"])/pd.Timedelta('1 hour')" + "df[\"length_of_stay\"] = (df[\"outtime\"]-df[\"intime\"])/pd.Timedelta('1 hour')" ] }, { @@ -405,8 +412,7 @@ "metadata": {}, "outputs": [], "source": [ - "# TODO SNOW-1635620: uncomment when TimeDelta is implemented\n", - "# df[\"age\"] = df[\"intime\"].dt.year-df[\"dob\"].dt.year" + "df[\"age\"] = df[\"intime\"].dt.year-df[\"dob\"].dt.year" ] }, { @@ -426,8 +432,7 @@ }, "outputs": [], "source": [ - "# TODO SNOW-1635620: uncomment when TimeDelta is implemented\n", - "# df = df[df[\"age\"]<100]" + "df = 
df[df[\"age\"]<100]" ] }, { @@ -447,10 +452,37 @@ "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:snowflake.snowpark.modin.plugin.utils.warning_message:DataFrame.plot materializes data to the local machine for plotting.\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# TODO SNOW-1635620: uncomment when TimeDelta is implemented\n", - "# df.plot(\"age\",\"length_of_stay\",kind=\"scatter\")" + "df.plot(\"age\",\"length_of_stay\",kind=\"scatter\")" ] }, { @@ -518,8 +550,8 @@ }, "outputs": [], "source": [ - "# TODO SNOW-1635620: uncomment when TimeDelta is implemented\n", - "# df[\"pre_icu_length_of_stay\"]= (df[\"intime\"]-df[\"admittime\"])/pd.Timedelta('1 day')" + "df[\"admittime\"] = pd.to_datetime(df[\"admittime\"])\n", + "df[\"pre_icu_length_of_stay\"]= (df[\"intime\"]-df[\"admittime\"])/pd.Timedelta('1 day')" ] }, { @@ -541,7 +573,7 @@ }, "outputs": [], "source": [ - "# TODO SNOW-1635620: uncomment when TimeDelta is implemented\n", + "# TODO(https://snowflakecomputing.atlassian.net/browse/SNOW-1640617): Implement Series.hist\n", "# df[\"pre_icu_length_of_stay\"].hist()" ] }, @@ -552,11 +584,18 @@ "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Percentage of ICU admissions within 1 day: 81.10%\n" + ] + } + ], "source": [ - "# TODO SNOW-1635620: uncomment when TimeDelta is implemented\n", - "# print(f\"Percentage of ICU admissions within 1 day: \\\n", - "# {len(df[df['pre_icu_length_of_stay']<1])/len(df)*100:.2f}%\")" + "print(f\"Percentage of ICU admissions within 1 day: \\\n", + " {len(df[df['pre_icu_length_of_stay']<1])/len(df)*100:.2f}%\")" ] }, { @@ -592,12 +631,12 @@ "3 HUMERAL FRACTURE\n", "4 ALCOHOLIC HEPATITIS\n", " ... 
\n", - "131 PERICARDIAL EFFUSION\n", - "132 ALTERED MENTAL STATUS\n", - "133 ACUTE RESPIRATORY DISTRESS SYNDROME;ACUTE RENA...\n", - "134 BRADYCARDIA\n", - "135 CHOLANGITIS\n", - "Name: diagnosis, Length: 136, dtype: object" + "122 SHORTNESS OF BREATH\n", + "123 PERICARDIAL EFFUSION\n", + "124 ACUTE RESPIRATORY DISTRESS SYNDROME;ACUTE RENA...\n", + "125 BRADYCARDIA\n", + "126 CHOLANGITIS\n", + "Name: diagnosis, Length: 127, dtype: object" ] }, "execution_count": 16, @@ -636,12 +675,12 @@ "3 HUMERAL FRACTURE\n", "4 ALCOHOLIC HEPATITIS\n", " ... \n", - "131 PERICARDIAL EFFUSION\n", - "132 ALTERED MENTAL STATUS\n", - "133 ACUTE RESPIRATORY DISTRESS SYNDROME ACUTE RENA...\n", - "134 BRADYCARDIA\n", - "135 CHOLANGITIS\n", - "Name: diagnosis, Length: 136, dtype: object" + "122 SHORTNESS OF BREATH\n", + "123 PERICARDIAL EFFUSION\n", + "124 ACUTE RESPIRATORY DISTRESS SYNDROME ACUTE RENA...\n", + "125 BRADYCARDIA\n", + "126 CHOLANGITIS\n", + "Name: diagnosis, Length: 127, dtype: object" ] }, "execution_count": 17, @@ -669,8 +708,8 @@ { "data": { "text/plain": [ - "[('SEPSIS', 10),\n", - " ('PNEUMONIA', 8),\n", + "[('SEPSIS', 9),\n", + " ('PNEUMONIA', 7),\n", " ('CONGESTIVE HEART FAILURE', 5),\n", " ('FEVER', 4),\n", " ('SHORTNESS OF BREATH', 4)]" @@ -769,8 +808,8 @@ "data": { "text/plain": [ "hospital_expire_flag\n", - "0 90\n", - "1 46\n", + "0 85\n", + "1 42\n", "Name: count, dtype: int64" ] }, @@ -892,15 +931,15 @@ " ...\n", " \n", " \n", - " 131\n", - " 0\n", + " 122\n", " 0\n", " 0\n", " 0\n", " 0\n", + " 1\n", " \n", " \n", - " 132\n", + " 123\n", " 0\n", " 0\n", " 0\n", @@ -908,7 +947,7 @@ " 0\n", " \n", " \n", - " 133\n", + " 124\n", " 0\n", " 0\n", " 0\n", @@ -916,7 +955,7 @@ " 0\n", " \n", " \n", - " 134\n", + " 125\n", " 0\n", " 0\n", " 0\n", @@ -924,7 +963,7 @@ " 0\n", " \n", " \n", - " 135\n", + " 126\n", " 0\n", " 0\n", " 0\n", @@ -933,7 +972,7 @@ " \n", " \n", "\n", - "

136 rows × 5 columns

\n", + "

127 rows × 5 columns

\n", "" ], "text/plain": [ @@ -944,13 +983,13 @@ "3 0 0 0 0 0\n", "4 0 0 0 0 0\n", ".. ... ... ... ... ...\n", - "131 0 0 0 0 0\n", - "132 0 0 0 0 0\n", - "133 0 0 0 0 0\n", - "134 0 0 0 0 0\n", - "135 0 0 0 0 0\n", + "122 0 0 0 0 1\n", + "123 0 0 0 0 0\n", + "124 0 0 0 0 0\n", + "125 0 0 0 0 0\n", + "126 0 0 0 0 0\n", "\n", - "[136 rows x 5 columns]" + "[127 rows x 5 columns]" ] }, "execution_count": 23, @@ -979,12 +1018,12 @@ "3 0\n", "4 1\n", " ..\n", - "131 0\n", - "132 1\n", - "133 0\n", - "134 0\n", - "135 0\n", - "Name: hospital_expire_flag, Length: 136, dtype: int8" + "122 0\n", + "123 0\n", + "124 0\n", + "125 0\n", + "126 0\n", + "Name: hospital_expire_flag, Length: 127, dtype: int8" ] }, "execution_count": 24, @@ -1460,7 +1499,7 @@ " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", "}\n", - "
GaussianNB()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + "
GaussianNB()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "GaussianNB()" @@ -1508,7 +1547,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 29, @@ -1517,7 +1556,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -1546,7 +1585,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Accuracy of the binary classifier = 0.64\n" + "Accuracy of the binary classifier = 0.62\n" ] } ], @@ -1586,6 +1625,18 @@ "\n", "Snowpark pandas lets you seamlessly move between feature engineering, visualization, and machine learning — all within the Python data ecosystem, while operating directly on the data in your data warehouse. \n" ] + }, + { + "cell_type": "markdown", + "id": "4e78a2bc", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "id": "698086ae", + "metadata": {}, + "source": [] } ], "metadata": { diff --git a/tests/unit/modin/test_class.py b/tests/unit/modin/test_class.py index 1d4b3881b7b..29aa1037d47 100644 --- a/tests/unit/modin/test_class.py +++ b/tests/unit/modin/test_class.py @@ -49,7 +49,6 @@ def test_class_equivalence(): assert pd.SparseDtype is native_pd.SparseDtype assert pd.StringDtype is native_pd.StringDtype assert pd.Timedelta is native_pd.Timedelta - assert pd.TimedeltaIndex is native_pd.TimedeltaIndex assert pd.Timestamp is native_pd.Timestamp assert pd.UInt8Dtype is native_pd.UInt8Dtype assert pd.UInt16Dtype is native_pd.UInt16Dtype