diff --git a/CHANGELOG.md b/CHANGELOG.md index c3da8d2dd0b..deb1a4d2fa5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ #### New Features - Added support for `TimedeltaIndex.mean` method. +- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. - Added support for `DatetimeIndex.mean` and `DatetimeIndex.std` methods. @@ -125,6 +126,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det - Added support for `Series.dt.total_seconds` method. - Added support for `DataFrame.apply(axis=0)`. - Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`. +- Added support for `DatetimeIndex.tz_convert` and `DatetimeIndex.tz_localize`. #### Improvements diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst index 325da109877..46cee7f6014 100644 --- a/docs/source/modin/supported/datetime_index_supported.rst +++ b/docs/source/modin/supported/datetime_index_supported.rst @@ -82,9 +82,9 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``snap`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``tz_convert`` | N | | | +| ``tz_convert`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``tz_localize`` | N | | | +| ``tz_localize`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``round`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 76e91b7da92..d8622299ea9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -956,10 +956,7 @@ def do_resolve_with_resolved_children( schema_query = schema_query_for_values_statement(logical_plan.output) if logical_plan.data: - if ( - len(logical_plan.output) * len(logical_plan.data) - < ARRAY_BIND_THRESHOLD - ): + if not logical_plan.is_large_local_data: return self.plan_builder.query( values_statement(logical_plan.output, logical_plan.data), logical_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index e3e032cd94b..aa8730dcf7f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -144,10 +144,27 @@ def __init__( self.data = data self.schema_query = schema_query + @property + def is_large_local_data(self) -> bool: + from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD + + return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: + if self.is_large_local_data: + # When the number of literals exceeds the threshold, we generate 3 queries: + # 1. 
create table query + # 2. insert into table query + # 3. select * from table query + # We only consider the complexity from the final select * query since other queries + # are built based on it. + return { + PlanNodeCategory.COLUMN: 1, + } + + # If we stay under the threshold, we generate a single query: # select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm) - # TODO: use ARRAY_BIND_THRESHOLD return { PlanNodeCategory.COLUMN: len(self.output), PlanNodeCategory.LITERAL: len(self.data) * len(self.output), diff --git a/src/snowflake/snowpark/_internal/telemetry.py b/src/snowflake/snowpark/_internal/telemetry.py index 8b9ef2acccb..aef60828334 100644 --- a/src/snowflake/snowpark/_internal/telemetry.py +++ b/src/snowflake/snowpark/_internal/telemetry.py @@ -79,6 +79,20 @@ class TelemetryField(Enum): QUERY_PLAN_HEIGHT = "query_plan_height" QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes" QUERY_PLAN_COMPLEXITY = "query_plan_complexity" + # temp table cleanup + TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup" + NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned" + NUM_TEMP_TABLES_CREATED = "num_temp_tables_created" + TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled" + TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = ( + "snowpark_temp_table_cleanup_abnormal_exception" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = ( + "temp_table_cleanup_abnormal_exception_table_name" + ) + TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = ( + "temp_table_cleanup_abnormal_exception_message" + ) # These DataFrame APIs call other DataFrame APIs @@ -464,3 +478,41 @@ def send_large_query_optimization_skipped_telemetry( }, } self.send(message) + + def send_temp_table_cleanup_telemetry( + self, + session_id: str, + temp_table_cleaner_enabled: bool, + num_temp_tables_cleaned: int, + num_temp_tables_created: int, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANER_ENABLED.value: temp_table_cleaner_enabled, + TelemetryField.NUM_TEMP_TABLES_CLEANED.value: num_temp_tables_cleaned, + TelemetryField.NUM_TEMP_TABLES_CREATED.value: num_temp_tables_created, + }, + } + self.send(message) + + def send_temp_table_cleanup_abnormal_exception_telemetry( + self, + session_id: str, + table_name: str, + exception_message: str, + ) -> None: + message = { + **self._create_basic_telemetry_data( + TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value + ), + TelemetryField.KEY_DATA.value: { + TelemetryField.SESSION_ID.value: session_id, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME.value: table_name, + TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE.value: exception_message, + }, + } + self.send(message) diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py index b9055c6fc58..4fa17498d34 100644 --- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py +++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py @@ -4,9 +4,7 @@ import logging import weakref from collections import defaultdict -from queue import Empty, Queue -from threading import Event, Thread -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable @@ -33,12 +31,6 @@ def __init__(self, session: 
"Session") -> None: # to its reference count for later temp table management # this dict will still be maintained even if the cleaner is stopped (`stop()` is called) self.ref_count_map: Dict[str, int] = defaultdict(int) - # unused temp table will be put into the queue for cleanup - self.queue: Queue = Queue() - # thread for removing temp tables (running DROP TABLE sql) - self.cleanup_thread: Optional[Thread] = None - # An event managing a flag that indicates whether the cleaner is started - self.stop_event = Event() def add(self, table: SnowflakeTable) -> None: self.ref_count_map[table.name] += 1 @@ -46,61 +38,60 @@ def add(self, table: SnowflakeTable) -> None: # and this table will be dropped finally _ = weakref.finalize(table, self._delete_ref_count, table.name) - def _delete_ref_count(self, name: str) -> None: + def _delete_ref_count(self, name: str) -> None: # pragma: no cover """ Decrements the reference count of a temporary table, and if the count reaches zero, puts this table in the queue for cleanup. """ self.ref_count_map[name] -= 1 if self.ref_count_map[name] == 0: - self.ref_count_map.pop(name) - # clean up - self.queue.put(name) + if self.session.auto_clean_up_temp_table_enabled: + self.drop_table(name) elif self.ref_count_map[name] < 0: logging.debug( f"Unexpected reference count {self.ref_count_map[name]} for table {name}" ) - def process_cleanup(self) -> None: - while not self.stop_event.is_set(): - try: - # it's non-blocking after timeout and become interruptable with stop_event - # it will raise an `Empty` exception if queue is empty after timeout, - # then we catch this exception and avoid breaking loop - table_name = self.queue.get(timeout=1) - self.drop_table(table_name) - except Empty: - continue - - def drop_table(self, name: str) -> None: + def drop_table(self, name: str) -> None: # pragma: no cover common_log_text = f"temp table {name} in session {self.session.session_id}" - logging.debug(f"Cleanup Thread: Ready to drop {common_log_text}") + logging.debug(f"Ready to drop {common_log_text}") + query_id = None try: - # TODO SNOW-1556553: Remove this workaround once multi-threading of Snowpark session is supported - with self.session._conn._conn.cursor() as cursor: - cursor.execute( - f"drop table if exists {name} /* internal query to drop unused temp table */", - _statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name}, + async_job = self.session.sql( + f"drop table if exists {name} /* internal query to drop unused temp table */", + )._internal_collect_with_tag_no_telemetry( + block=False, statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name} + ) + query_id = async_job.query_id + logging.debug(f"Dropping {common_log_text} with query id {query_id}") + except Exception as ex: # pragma: no cover + warning_message = f"Failed to drop {common_log_text}, exception: {ex}" + logging.warning(warning_message) + if query_id is None: + # If no query_id is available, it means the query haven't been accepted by gs, + # and it won't occur in our job_etl_view, send a separate telemetry for recording. 
+ self.session._conn._telemetry_client.send_temp_table_cleanup_abnormal_exception_telemetry( + self.session.session_id, + name, + str(ex), ) - logging.debug(f"Cleanup Thread: Successfully dropped {common_log_text}") - except Exception as ex: - logging.warning( - f"Cleanup Thread: Failed to drop {common_log_text}, exception: {ex}" - ) # pragma: no cover - - def is_alive(self) -> bool: - return self.cleanup_thread is not None and self.cleanup_thread.is_alive() - - def start(self) -> None: - self.stop_event.clear() - if not self.is_alive(): - self.cleanup_thread = Thread(target=self.process_cleanup) - self.cleanup_thread.start() def stop(self) -> None: """ - The cleaner will stop immediately and leave unfinished temp tables in the queue. + Stops the cleaner (no-op) and sends the telemetry. """ - self.stop_event.set() - if self.is_alive(): - self.cleanup_thread.join() + self.session._conn._telemetry_client.send_temp_table_cleanup_telemetry( + self.session.session_id, + temp_table_cleaner_enabled=self.session.auto_clean_up_temp_table_enabled, + num_temp_tables_cleaned=self.num_temp_tables_cleaned, + num_temp_tables_created=self.num_temp_tables_created, + ) + + @property + def num_temp_tables_created(self) -> int: + return len(self.ref_count_map) + + @property + def num_temp_tables_cleaned(self) -> int: + # TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled + return sum(v == 0 for v in self.ref_count_map.values()) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 01ccad8f430..0005df924db 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -9,7 +9,7 @@ from collections.abc import Hashable, Iterable from functools import partial from inspect import getmembers -from types import BuiltinFunctionType +from types import BuiltinFunctionType, MappingProxyType from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Union import numpy as np @@ -56,6 +56,7 @@ stddev, stddev_pop, sum as sum_, + trunc, var_pop, variance, when, @@ -65,6 +66,9 @@ OrderedDataFrame, OrderingColumn, ) +from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + TimedeltaType, +) from snowflake.snowpark.modin.plugin._internal.utils import ( from_pandas_label, pandas_lit, @@ -85,7 +89,7 @@ } -def array_agg_keepna( +def _array_agg_keepna( column_to_aggregate: ColumnOrName, ordering_columns: Iterable[OrderingColumn] ) -> Column: """ @@ -239,62 +243,63 @@ def _columns_coalescing_idxmax_idxmin_helper( ) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding snowflake builtin aggregation function for axis=0. If any change -# is made to this map, ensure GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE and -# GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES are updated accordingly. 
-SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": count, - "mean": mean, - "min": min_, - "max": max_, - "idxmax": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmax" - ), - "idxmin": functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmin" - ), - "sum": sum_, - "median": median, - "skew": skew, - "std": stddev, - "var": variance, - "all": builtin("booland_agg"), - "any": builtin("boolor_agg"), - np.max: max_, - np.min: min_, - np.sum: sum_, - np.mean: mean, - np.median: median, - np.std: stddev, - np.var: variance, - "array_agg": array_agg, - "quantile": column_quantile, - "nunique": count_distinct, -} -GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = ( - "min", - "max", - "sum", - "mean", - "median", - "std", - np.max, - np.min, - np.sum, - np.mean, - np.median, - np.std, -) -GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = ( - "any", - "all", - "count", - "idxmax", - "idxmin", - "size", - "nunique", -) +class _SnowparkPandasAggregation(NamedTuple): + """ + A representation of a Snowpark pandas aggregation. + + This structure gives us a common representation for an aggregation that may + have multiple aliases, like "sum" and np.sum. + """ + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. + # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool + + # This callable takes a single Snowpark column as input and aggregates the + # column on axis=0. If None, Snowpark pandas does not support this + # aggregation on axis=0. + axis_0_aggregation: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # the columns on axis=1 with skipna=True, i.e. not including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=True. + axis_1_aggregation_skipna: Optional[Callable] = None + + # This callable takes one or more Snowpark columns as input and + # the columns on axis=1 with skipna=False, i.e. including nulls in the + # aggregation. If None, Snowpark pandas does not support this aggregation + # on axis=1 with skipna=False. + axis_1_aggregation_keepna: Optional[Callable] = None + + +class SnowflakeAggFunc(NamedTuple): + """ + A Snowflake aggregation, including information about how the aggregation acts on SnowparkPandasType. + """ + + # The aggregation function in Snowpark. + # For aggregation on axis=0, this field should take a single Snowpark + # column and return the aggregated column. + # For aggregation on axis=1, this field should take an arbitrary number + # of Snowpark columns and return the aggregated column. + snowpark_aggregation: Callable + + # This field tells whether if types of all the inputs of the function are + # the same instance of SnowparkPandasType, the type of the result is the + # same instance of SnowparkPandasType. Note that this definition applies + # whether the aggregation is on axis=0 or axis=1. For example, the sum of + # a single timedelta column on axis 0 is another timedelta column. 
+ # Equivalently, the sum of two timedelta columns along axis 1 is also + # another timedelta column. Therefore, preserves_snowpark_pandas_types for + # sum would be True. + preserves_snowpark_pandas_types: bool class AggFuncWithLabel(NamedTuple): @@ -413,35 +418,143 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable: return sum(builtin("zeroifnull")(col) for col in cols) -# Map between the pandas input aggregation function (str or numpy function) and -# the corresponding aggregation function for axis=1 when skipna=True. The returned aggregation -# function may either be a builtin aggregation function, or a function taking in *arg columns -# that then calls the appropriate builtin aggregations. -SNOWFLAKE_COLUMNS_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "count": _columns_count, - "sum": _columns_coalescing_sum, - np.sum: _columns_coalescing_sum, - "min": _columns_coalescing_min, - "max": _columns_coalescing_max, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - np.min: _columns_coalescing_min, - np.max: _columns_coalescing_max, -} +def _create_pandas_to_snowpark_pandas_aggregation_map( + pandas_functions: Iterable[AggFuncTypeBase], + snowpark_pandas_aggregation: _SnowparkPandasAggregation, +) -> MappingProxyType[AggFuncTypeBase, _SnowparkPandasAggregation]: + """ + Create a map from the given pandas functions to the given _SnowparkPandasAggregation. -# These functions are called instead if skipna=False -SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { - "min": least, - "max": greatest, - "idxmax": _columns_coalescing_idxmax_idxmin_helper, - "idxmin": _columns_coalescing_idxmax_idxmin_helper, - # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark - # sum_, since Snowpark sum_ gets the sum of all rows within a single column. - "sum": lambda *cols: sum(cols), - np.sum: lambda *cols: sum(cols), - np.min: least, - np.max: greatest, -} + Args; + pandas_functions: The pandas functions that map to the given aggregation. + snowpark_pandas_aggregation: The aggregation to map to + + Returns: + The map. + """ + return MappingProxyType({k: snowpark_pandas_aggregation for k in pandas_functions}) + + +# Map between the pandas input aggregation function (str or numpy function) and +# _SnowparkPandasAggregation representing information about applying the +# aggregation in Snowpark pandas. 
+_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[ + AggFuncTypeBase, _SnowparkPandasAggregation +] = MappingProxyType( + { + "count": _SnowparkPandasAggregation( + axis_0_aggregation=count, + axis_1_aggregation_skipna=_columns_count, + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("mean", np.mean), + _SnowparkPandasAggregation( + axis_0_aggregation=mean, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("min", np.min), + _SnowparkPandasAggregation( + axis_0_aggregation=min_, + axis_1_aggregation_keepna=least, + axis_1_aggregation_skipna=_columns_coalescing_min, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("max", np.max), + _SnowparkPandasAggregation( + axis_0_aggregation=max_, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=_columns_coalescing_max, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("sum", np.sum), + _SnowparkPandasAggregation( + axis_0_aggregation=sum_, + # IMPORTANT: count and sum use python builtin sum to invoke + # __add__ on each column rather than Snowpark sum_, since + # Snowpark sum_ gets the sum of all rows within a single column. + axis_1_aggregation_keepna=lambda *cols: sum(cols), + axis_1_aggregation_skipna=_columns_coalescing_sum, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("median", np.median), + _SnowparkPandasAggregation( + axis_0_aggregation=median, + preserves_snowpark_pandas_types=True, + ), + ), + "idxmax": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmax" + ), + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "idxmin": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmin" + ), + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "skew": _SnowparkPandasAggregation( + axis_0_aggregation=skew, + preserves_snowpark_pandas_types=True, + ), + "all": _SnowparkPandasAggregation( + # all() for a column with no non-null values is NULL in Snowflake, but True in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("booland_agg")(col(c)), pandas_lit(True) + ), + preserves_snowpark_pandas_types=False, + ), + "any": _SnowparkPandasAggregation( + # any() for a column with no non-null values is NULL in Snowflake, but False in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("boolor_agg")(col(c)), pandas_lit(False) + ), + preserves_snowpark_pandas_types=False, + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("std", np.std), + _SnowparkPandasAggregation( + axis_0_aggregation=stddev, + preserves_snowpark_pandas_types=True, + ), + ), + **_create_pandas_to_snowpark_pandas_aggregation_map( + ("var", np.var), + _SnowparkPandasAggregation( + axis_0_aggregation=variance, + # variance units are the square of the input column units, so + # variance does not preserve types. 
+ preserves_snowpark_pandas_types=False, + ), + ), + "array_agg": _SnowparkPandasAggregation( + axis_0_aggregation=array_agg, + preserves_snowpark_pandas_types=False, + ), + "quantile": _SnowparkPandasAggregation( + axis_0_aggregation=column_quantile, + preserves_snowpark_pandas_types=True, + ), + "nunique": _SnowparkPandasAggregation( + axis_0_aggregation=count_distinct, + preserves_snowpark_pandas_types=False, + ), + } +) class AggregateColumnOpParameters(NamedTuple): @@ -462,7 +575,7 @@ class AggregateColumnOpParameters(NamedTuple): agg_snowflake_quoted_identifier: str # the snowflake aggregation function to apply on the column - snowflake_agg_func: Callable + snowflake_agg_func: SnowflakeAggFunc # the columns specifying the order of rows in the column. This is only # relevant for aggregations that depend on row order, e.g. summing a string @@ -471,88 +584,108 @@ class AggregateColumnOpParameters(NamedTuple): def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool: - return agg_func in SNOWFLAKE_BUILTIN_AGG_FUNC_MAP + return agg_func in _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION def get_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int = 0 -) -> Optional[Callable]: + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] +) -> Optional[SnowflakeAggFunc]: """ Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function. If no corresponding snowflake aggregation function can be found, return None. """ - if axis == 0: - snowflake_agg_func = SNOWFLAKE_BUILTIN_AGG_FUNC_MAP.get(agg_func) - if snowflake_agg_func == stddev or snowflake_agg_func == variance: - # for aggregation function std and var, we only support ddof = 0 or ddof = 1. - # when ddof is 1, std is mapped to stddev, var is mapped to variance - # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop - # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 - ddof = agg_kwargs.get("ddof", 1) - if ddof != 1 and ddof != 0: - return None - if ddof == 0: - return stddev_pop if snowflake_agg_func == stddev else var_pop - elif snowflake_agg_func == column_quantile: - interpolation = agg_kwargs.get("interpolation", "linear") - q = agg_kwargs.get("q", 0.5) - if interpolation not in ("linear", "nearest"): - return None - if not is_scalar(q): - # SNOW-1062878 Because list-like q would return multiple rows, calling quantile - # through the aggregate frontend in this manner is unsupported. - return None - return lambda col: column_quantile(col, interpolation, q) - elif agg_func in ("all", "any"): - # If there are no rows in the input frame, the function will also return NULL, which should - # instead by TRUE for "all" and FALSE for "any". - # Need to wrap column name in IDENTIFIER, or else the agg function will treat the name - # as a string literal. 
- # The generated SQL expression for "all" is - # IFNULL(BOOLAND_AGG(IDENTIFIER("column_name")), TRUE) - # The expression for "any" is - # IFNULL(BOOLOR_AGG(IDENTIFIER("column_name")), FALSE) - default_value = bool(agg_func == "all") - return lambda col: builtin("ifnull")( - # mypy refuses to acknowledge snowflake_agg_func is non-NULL here - snowflake_agg_func(builtin("identifier")(col)), # type: ignore[misc] - pandas_lit(default_value), + if axis == 1: + return _generate_rowwise_aggregation_function(agg_func, agg_kwargs) + + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + + if snowpark_pandas_aggregation is None: + # We don't have any implementation at all for this aggregation. + return None + + snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation + + if snowpark_aggregation is None: + # We don't have an implementation on axis=0 for this aggregation. + return None + + # Rewrite some aggregations according to `agg_kwargs.` + if snowpark_aggregation == stddev or snowpark_aggregation == variance: + # for aggregation function std and var, we only support ddof = 0 or ddof = 1. + # when ddof is 1, std is mapped to stddev, var is mapped to variance + # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop + # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 + ddof = agg_kwargs.get("ddof", 1) + if ddof != 1 and ddof != 0: + return None + if ddof == 0: + snowpark_aggregation = ( + stddev_pop if snowpark_aggregation == stddev else var_pop ) - else: - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) + elif snowpark_aggregation == column_quantile: + interpolation = agg_kwargs.get("interpolation", "linear") + q = agg_kwargs.get("q", 0.5) + if interpolation not in ("linear", "nearest"): + return None + if not is_scalar(q): + # SNOW-1062878 Because list-like q would return multiple rows, calling quantile + # through the aggregate frontend in this manner is unsupported. + return None + + def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: + return column_quantile(col, interpolation, q) - return snowflake_agg_func + assert ( + snowpark_aggregation is not None + ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation." + return SnowflakeAggFunc( + snowpark_aggregation=snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def generate_rowwise_aggregation_function( +def _generate_rowwise_aggregation_function( agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any] -) -> Optional[Callable]: +) -> Optional[SnowflakeAggFunc]: """ Get a callable taking *arg columns to apply for an aggregation. Unlike get_snowflake_agg_func, this function may return a wrapped composition of Snowflake builtin functions depending on the values of the specified kwargs. 
""" - snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) - if not agg_kwargs.get("skipna", True): - snowflake_agg_func = SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP.get( - agg_func, snowflake_agg_func - ) + snowpark_pandas_aggregation = ( + _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION.get(agg_func) + ) + if snowpark_pandas_aggregation is None: + return None + snowpark_aggregation = ( + snowpark_pandas_aggregation.axis_1_aggregation_skipna + if agg_kwargs.get("skipna", True) + else snowpark_pandas_aggregation.axis_1_aggregation_keepna + ) + if snowpark_aggregation is None: + return None min_count = agg_kwargs.get("min_count", 0) if min_count > 0: + original_aggregation = snowpark_aggregation + # Create a case statement to check if the number of non-null values exceeds min_count # when min_count > 0, if the number of not NULL values is < min_count, return NULL. - def agg_func_wrapper(fn: Callable) -> Callable: - return lambda *cols: when( - _columns_count(*cols) < min_count, pandas_lit(None) - ).otherwise(fn(*cols)) + def snowpark_aggregation(*cols: SnowparkColumn) -> SnowparkColumn: + return when(_columns_count(*cols) < min_count, pandas_lit(None)).otherwise( + original_aggregation(*cols) + ) - return snowflake_agg_func and agg_func_wrapper(snowflake_agg_func) - return snowflake_agg_func + return SnowflakeAggFunc( + snowpark_aggregation, + preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, + ) -def is_supported_snowflake_agg_func( - agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int +def _is_supported_snowflake_agg_func( + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ check if the aggregation function is supported with snowflake. Current supported @@ -566,12 +699,14 @@ def is_supported_snowflake_agg_func( is_valid: bool. Whether it is valid to implement with snowflake or not. """ if isinstance(agg_func, tuple) and len(agg_func) == 2: + # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`, + # take the second part of the named aggregation. agg_func = agg_func[0] return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None -def are_all_agg_funcs_supported_by_snowflake( - agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: int +def _are_all_agg_funcs_supported_by_snowflake( + agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: Literal[0, 1] ) -> bool: """ Check if all aggregation functions in the given list are snowflake supported @@ -582,14 +717,14 @@ def are_all_agg_funcs_supported_by_snowflake( return False. 
""" return all( - is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs + _is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs ) def check_is_aggregation_supported_in_snowflake( agg_func: AggFuncType, agg_kwargs: dict[str, Any], - axis: int, + axis: Literal[0, 1], ) -> bool: """ check if distributed implementation with snowflake is available for the aggregation @@ -608,18 +743,18 @@ def check_is_aggregation_supported_in_snowflake( if is_dict_like(agg_func): return all( ( - are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) + _are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) if is_list_like(value) and not is_named_tuple(value) - else is_supported_snowflake_agg_func(value, agg_kwargs, axis) + else _is_supported_snowflake_agg_func(value, agg_kwargs, axis) ) for value in agg_func.values() ) elif is_list_like(agg_func): - return are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) - return is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) + return _are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) + return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) -def is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: +def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: """ Is the given snowflake aggregation function needs to be applied on the numeric column. """ @@ -697,7 +832,7 @@ def drop_non_numeric_data_columns( ) -def generate_aggregation_column( +def _generate_aggregation_column( agg_column_op_params: AggregateColumnOpParameters, agg_kwargs: dict[str, Any], is_groupby_agg: bool, @@ -721,8 +856,14 @@ def generate_aggregation_column( SnowparkColumn after the aggregation function. The column is also aliased back to the original name """ snowpark_column = agg_column_op_params.snowflake_quoted_identifier - snowflake_agg_func = agg_column_op_params.snowflake_agg_func - if is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( + snowflake_agg_func = agg_column_op_params.snowflake_agg_func.snowpark_aggregation + + if snowflake_agg_func in (variance, var_pop) and isinstance( + agg_column_op_params.data_type, TimedeltaType + ): + raise TypeError("timedelta64 type does not support var operations") + + if _is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( agg_column_op_params.data_type, BooleanType ): # if the column is a boolean column and the aggregation function requires numeric values, @@ -753,7 +894,7 @@ def generate_aggregation_column( # note that we always assume keepna for array_agg. TODO(SNOW-1040398): # make keepna treatment consistent across array_agg and other # aggregation methods. - agg_snowpark_column = array_agg_keepna( + agg_snowpark_column = _array_agg_keepna( snowpark_column, ordering_columns=agg_column_op_params.ordering_columns ) elif ( @@ -825,6 +966,19 @@ def generate_aggregation_column( ), f"No case expression is constructed with skipna({skipna}), min_count({min_count})" agg_snowpark_column = case_expr.otherwise(agg_snowpark_column) + if ( + isinstance(agg_column_op_params.data_type, TimedeltaType) + and agg_column_op_params.snowflake_agg_func.preserves_snowpark_pandas_types + ): + # timedelta aggregations that produce timedelta results might produce + # a decimal type in snowflake, e.g. + # pd.Series([pd.Timestamp(1), pd.Timestamp(2)]).mean() produces 1.5 in + # Snowflake. We truncate the decimal part of the result, as pandas + # does. 
+ agg_snowpark_column = cast( + trunc(agg_snowpark_column), agg_column_op_params.data_type.snowpark_type + ) + # rename the column to agg_column_quoted_identifier agg_snowpark_column = agg_snowpark_column.as_( agg_column_op_params.agg_snowflake_quoted_identifier @@ -857,7 +1011,7 @@ def aggregate_with_ordered_dataframe( is_groupby_agg = groupby_columns is not None agg_list: list[SnowparkColumn] = [ - generate_aggregation_column( + _generate_aggregation_column( agg_column_op_params=agg_col_op, agg_kwargs=agg_kwargs, is_groupby_agg=is_groupby_agg, @@ -973,7 +1127,7 @@ def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str: ) -def generate_pandas_labels_for_agg_result_columns( +def _generate_pandas_labels_for_agg_result_columns( pandas_label: Hashable, num_levels: int, agg_func_list: list[AggFuncInfo], @@ -1102,7 +1256,7 @@ def generate_column_agg_info( ) # generate the pandas label and quoted identifier for the result aggregation columns, one # for each aggregation function to apply. - agg_col_labels = generate_pandas_labels_for_agg_result_columns( + agg_col_labels = _generate_pandas_labels_for_agg_result_columns( pandas_label_to_identifier.pandas_label, num_levels, agg_func_list, # type: ignore[arg-type] diff --git a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py index 3bf1062107e..e7a96b49ef1 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py @@ -520,12 +520,15 @@ def single_pivot_helper( data_column_snowflake_quoted_identifiers: new data column snowflake quoted identifiers this pivot result data_column_pandas_labels: new data column pandas labels for this pivot result """ - snowpark_aggr_func = get_snowflake_agg_func(pandas_aggr_func_name, {}) - if not is_supported_snowflake_pivot_agg_func(snowpark_aggr_func): + snowflake_agg_func = get_snowflake_agg_func(pandas_aggr_func_name, {}, axis=0) + if snowflake_agg_func is None or not is_supported_snowflake_pivot_agg_func( + snowflake_agg_func.snowpark_aggregation + ): # TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations raise ErrorMessage.not_implemented( f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments." ) + snowpark_aggr_func = snowflake_agg_func.snowpark_aggregation pandas_aggr_label, aggr_snowflake_quoted_identifier = value_label_to_identifier_pair @@ -1231,17 +1234,19 @@ def get_margin_aggregation( Returns: Snowpark column expression for the aggregation function result. """ - resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}) + resolved_aggfunc = get_snowflake_agg_func(aggfunc, {}, axis=0) # This would have been resolved during the original pivot at an early stage. 
assert resolved_aggfunc is not None, "resolved_aggfunc is None" - aggfunc_expr = resolved_aggfunc(snowflake_quoted_identifier) + aggregation_expression = resolved_aggfunc.snowpark_aggregation( + snowflake_quoted_identifier + ) - if resolved_aggfunc == sum_: - aggfunc_expr = coalesce(aggfunc_expr, pandas_lit(0)) + if resolved_aggfunc.snowpark_aggregation == sum_: + aggregation_expression = coalesce(aggregation_expression, pandas_lit(0)) - return aggfunc_expr + return aggregation_expression def expand_pivot_result_with_pivot_table_margins_no_groupby_columns( diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py index f8629e664f3..3b714087535 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py @@ -525,7 +525,7 @@ def tz_convert_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column: The column after conversion to the specified timezone """ if tz is None: - return convert_timezone(pandas_lit("UTC"), column) + return to_timestamp_ntz(convert_timezone(pandas_lit("UTC"), column)) else: if isinstance(tz, dt.tzinfo): tz_name = tz.tzname(None) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index b5022bff46b..8ef3bdf9bee 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -149,8 +149,6 @@ ) from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( AGG_NAME_COL_LABEL, - GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE, - GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES, AggFuncInfo, AggFuncWithLabel, AggregateColumnOpParameters, @@ -161,7 +159,6 @@ convert_agg_func_arg_to_col_agg_func_map, drop_non_numeric_data_columns, generate_column_agg_info, - generate_rowwise_aggregation_function, get_agg_func_to_col_map, get_pandas_aggr_func_name, get_snowflake_agg_func, @@ -3556,42 +3553,22 @@ def convert_func_to_agg_func_info( agg_col_ops, new_data_column_index_names = generate_column_agg_info( internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby ) - # Get the column aggregation functions used to check if the function - # preserves Snowpark pandas types. - agg_col_funcs = [] - for _, func in column_to_agg_func.items(): - if is_list_like(func) and not is_named_tuple(func): - for fn in func: - agg_col_funcs.append(fn.func) - else: - agg_col_funcs.append(func.func) # the pandas label and quoted identifier generated for each result column # after aggregation will be used as new pandas label and quoted identifiers. 
new_data_column_pandas_labels = [] new_data_column_quoted_identifiers = [] new_data_column_snowpark_pandas_types = [] - for i in range(len(agg_col_ops)): - col_agg_op = agg_col_ops[i] - col_agg_func = agg_col_funcs[i] - new_data_column_pandas_labels.append(col_agg_op.agg_pandas_label) + for agg_col_op in agg_col_ops: + new_data_column_pandas_labels.append(agg_col_op.agg_pandas_label) new_data_column_quoted_identifiers.append( - col_agg_op.agg_snowflake_quoted_identifier + agg_col_op.agg_snowflake_quoted_identifier + ) + new_data_column_snowpark_pandas_types.append( + agg_col_op.data_type + if isinstance(agg_col_op.data_type, SnowparkPandasType) + and agg_col_op.snowflake_agg_func.preserves_snowpark_pandas_types + else None ) - if col_agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE: - new_data_column_snowpark_pandas_types.append( - col_agg_op.data_type - if isinstance(col_agg_op.data_type, SnowparkPandasType) - else None - ) - elif col_agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES: - # In the case where the aggregation overrides the type of the output data column - # (e.g. any always returns boolean data columns), set the output Snowpark pandas type - # of the given column to None - new_data_column_snowpark_pandas_types.append(None) # type: ignore - else: - self._raise_not_implemented_error_for_timedelta() - new_data_column_snowpark_pandas_types = None # type: ignore - # The ordering of the named aggregations is changed by us when we process # the agg_kwargs into the func dict (named aggregations on the same # column are moved to be contiguous, see groupby.py::aggregate for an @@ -3644,7 +3621,7 @@ def convert_func_to_agg_func_info( ), agg_pandas_label=None, agg_snowflake_quoted_identifier=row_position_quoted_identifier, - snowflake_agg_func=min_, + snowflake_agg_func=get_snowflake_agg_func("min", agg_kwargs={}, axis=0), ordering_columns=internal_frame.ordering_columns, ) agg_col_ops.append(row_position_agg_column_op) @@ -5657,8 +5634,6 @@ def agg( args: the arguments passed for the aggregation kwargs: keyword arguments passed for the aggregation function. """ - self._raise_not_implemented_error_for_timedelta() - numeric_only = kwargs.get("numeric_only", False) # Call fallback if the aggregation function passed in the arg is currently not supported # by snowflake engine. 
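# --- Reviewer note: illustrative sketch, not part of this patch. -----------------------
# It shows how the refactored aggregation lookup above is meant to be consumed: instead
# of fetching a bare callable from SNOWFLAKE_BUILTIN_AGG_FUNC_MAP, callers now get a
# SnowflakeAggFunc whose preserves_snowpark_pandas_types flag is what lets Timedelta
# columns keep their type through axis=0 aggregations. Module path and names are taken
# from this diff; treat the snippet as a sketch against that assumed package layout.
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
    get_snowflake_agg_func,
)

agg = get_snowflake_agg_func("sum", agg_kwargs={}, axis=0)
assert agg is not None  # "sum" has an axis=0 implementation in the new map
# agg.snowpark_aggregation is the Snowpark callable (snowflake.snowpark.functions.sum),
# and the flag is True because summing Timedelta inputs yields another Timedelta column.
print(agg.preserves_snowpark_pandas_types)

# Unsupported aggregations return None instead of raising, so callers can fall back.
print(get_snowflake_agg_func("not_a_real_aggregation", agg_kwargs={}, axis=0))
# ----------------------------------------------------------------------------------------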
@@ -5704,6 +5679,11 @@ def agg( not is_list_like(value) for value in func.values() ) if axis == 1: + if any( + isinstance(t, TimedeltaType) + for t in internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values() + ): + ErrorMessage.not_implemented_for_timedelta("agg(axis=1)") if self.is_multiindex(): # TODO SNOW-1010307 fix axis=1 behavior with MultiIndex ErrorMessage.not_implemented( @@ -5761,9 +5741,9 @@ def agg( pandas_column_labels=frame.data_column_pandas_labels, ) if agg_arg in ("idxmin", "idxmax") - else generate_rowwise_aggregation_function(agg_arg, kwargs)( - *(col(c) for c in data_col_identifiers) - ) + else get_snowflake_agg_func( + agg_arg, kwargs, axis=1 + ).snowpark_aggregation(*(col(c) for c in data_col_identifiers)) for agg_arg in agg_args } pandas_labels = list(agg_col_map.keys()) @@ -5883,7 +5863,13 @@ def generate_agg_qc( index_column_snowflake_quoted_identifiers=[ agg_name_col_quoted_identifier ], - data_column_types=None, + data_column_types=[ + col.data_type + if isinstance(col.data_type, SnowparkPandasType) + and col.snowflake_agg_func.preserves_snowpark_pandas_types + else None + for col in col_agg_infos + ], index_column_types=None, ) return SnowflakeQueryCompiler(single_agg_dataframe) @@ -9129,7 +9115,9 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler": SnowflakeQueryCompiler Transposed new QueryCompiler object. """ - self._raise_not_implemented_error_for_timedelta() + if len(set(self._modin_frame.cached_data_column_snowpark_pandas_types)) > 1: + # In this case, transpose may lose types. + self._raise_not_implemented_error_for_timedelta() frame = self._modin_frame @@ -12513,8 +12501,6 @@ def _quantiles_single_col( column would allow us to create an accurate row position column, but would require a potentially expensive JOIN operator afterwards to apply the correct index labels. """ - self._raise_not_implemented_error_for_timedelta() - assert len(self._modin_frame.data_column_pandas_labels) == 1 if index is not None: @@ -12579,7 +12565,7 @@ def _quantiles_single_col( ], index_column_pandas_labels=[None], index_column_snowflake_quoted_identifiers=[index_identifier], - data_column_types=None, + data_column_types=original_frame.cached_data_column_snowpark_pandas_types, index_column_types=None, ) # We cannot call astype() directly to convert an index column, so we replicate @@ -13613,6 +13599,16 @@ def _window_agg( } ).frame else: + snowflake_agg_func = get_snowflake_agg_func(agg_func, agg_kwargs, axis=0) + if snowflake_agg_func is None: + # We don't have test coverage for this situation because we + # test individual rolling and expanding methods we've implemented, + # like rolling_sum(), but other rolling methods raise + # NotImplementedError immediately. We also don't support rolling + # agg(), which might take us here. 
+ ErrorMessage.not_implemented( # pragma: no cover + f"Window aggregation does not support the aggregation {repr_aggregate_function(agg_func, agg_kwargs)}" + ) new_frame = frame.update_snowflake_quoted_identifiers_with_expressions( { # If aggregation is count use count on row_position_quoted_identifier @@ -13623,7 +13619,7 @@ def _window_agg( if agg_func == "count" else count(col(quoted_identifier)).over(window_expr) >= min_periods, - get_snowflake_agg_func(agg_func, agg_kwargs)( + snowflake_agg_func.snowpark_aggregation( # Expanding is cumulative so replace NULL with 0 for sum aggregation builtin("zeroifnull")(col(quoted_identifier)) if window_func == WindowFunction.EXPANDING @@ -14577,8 +14573,6 @@ def idxmax( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmax", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -14603,8 +14597,6 @@ def idxmin( Returns: SnowflakeQueryCompiler """ - self._raise_not_implemented_error_for_timedelta() - return self._idxmax_idxmin( func="idxmin", axis=axis, skipna=skipna, numeric_only=numeric_only ) @@ -16668,6 +16660,7 @@ def dt_tz_localize( tz: Union[str, tzinfo], ambiguous: str = "raise", nonexistent: str = "raise", + include_index: bool = False, ) -> "SnowflakeQueryCompiler": """ Localize tz-naive to tz-aware. @@ -16675,39 +16668,50 @@ def dt_tz_localize( tz : str, pytz.timezone, optional ambiguous : {"raise", "inner", "NaT"} or bool mask, default: "raise" nonexistent : {"raise", "shift_forward", "shift_backward, "NaT"} or pandas.timedelta, default: "raise" + include_index: Whether to include the index columns in the operation. Returns: BaseQueryCompiler New QueryCompiler containing values with localized time zone. """ + dtype = self.index_dtypes[0] if include_index else self.dtypes[0] + if not include_index: + method_name = "Series.dt.tz_localize" + else: + assert is_datetime64_any_dtype(dtype), "column must be datetime" + method_name = "DatetimeIndex.tz_localize" + if not isinstance(ambiguous, str) or ambiguous != "raise": - ErrorMessage.parameter_not_implemented_error( - "ambiguous", "Series.dt.tz_localize" - ) + ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) if not isinstance(nonexistent, str) or nonexistent != "raise": - ErrorMessage.parameter_not_implemented_error( - "nonexistent", "Series.dt.tz_localize" - ) + ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - lambda column: tz_localize_column(column, tz) + lambda column: tz_localize_column(column, tz), + include_index, ) ) - def dt_tz_convert(self, tz: Union[str, tzinfo]) -> "SnowflakeQueryCompiler": + def dt_tz_convert( + self, + tz: Union[str, tzinfo], + include_index: bool = False, + ) -> "SnowflakeQueryCompiler": """ Convert time-series data to the specified time zone. Args: tz : str, pytz.timezone + include_index: Whether to include the index columns in the operation. Returns: A new QueryCompiler containing values with converted time zone. 
""" return SnowflakeQueryCompiler( self._modin_frame.apply_snowpark_function_to_columns( - lambda column: tz_convert_column(column, tz) + lambda column: tz_convert_column(column, tz), + include_index, ) ) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index 2ad902e8a4e..d8982f11c97 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -961,7 +961,6 @@ def snap(self, freq: Frequency = "S") -> DatetimeIndex: DatetimeIndex(['2023-01-01', '2023-01-01', '2023-02-01', '2023-02-01'], dtype='datetime64[ns]', freq=None) """ - @datetime_index_not_implemented() def tz_convert(self, tz) -> DatetimeIndex: """ Convert tz-aware Datetime Array/Index from one time zone to another. @@ -1026,8 +1025,14 @@ def tz_convert(self, tz) -> DatetimeIndex: '2014-08-01 09:00:00'], dtype='datetime64[ns]', freq='h') """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. + return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_convert( + tz, + include_index=True, + ) + ) - @datetime_index_not_implemented() def tz_localize( self, tz, @@ -1105,21 +1110,29 @@ def tz_localize( Localize DatetimeIndex in US/Eastern time zone: - >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP - >>> tz_aware # doctest: +SKIP - DatetimeIndex(['2018-03-01 09:00:00-05:00', - '2018-03-02 09:00:00-05:00', + >>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') + >>> tz_aware + DatetimeIndex(['2018-03-01 09:00:00-05:00', '2018-03-02 09:00:00-05:00', '2018-03-03 09:00:00-05:00'], - dtype='datetime64[ns, US/Eastern]', freq=None) + dtype='datetime64[ns, UTC-05:00]', freq=None) With the ``tz=None``, we can remove the time zone information while keeping the local time (not converted to UTC): - >>> tz_aware.tz_localize(None) # doctest: +SKIP + >>> tz_aware.tz_localize(None) DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00', '2018-03-03 09:00:00'], dtype='datetime64[ns]', freq=None) """ + # TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests. 
+ return DatetimeIndex( + query_compiler=self._query_compiler.dt_tz_localize( + tz, + ambiguous, + nonexistent, + include_index=True, + ) + ) def round( self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 1dbb743aa32..9cb4ffa7327 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -32,12 +32,7 @@ from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable from pandas.core.dtypes.common import is_timedelta64_dtype -from snowflake.snowpark import functions as fn from snowflake.snowpark.modin.pandas import DataFrame, Series -from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( - AggregateColumnOpParameters, - aggregate_with_ordered_dataframe, -) from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, ) @@ -45,7 +40,6 @@ from snowflake.snowpark.modin.plugin.utils.error_message import ( timedelta_index_not_implemented, ) -from snowflake.snowpark.types import LongType _CONSTRUCTOR_DEFAULTS = { "unit": lib.no_default, @@ -433,19 +427,25 @@ def mean( raise ValueError( f"axis should be 0 for TimedeltaIndex.mean, found '{axis}'" ) - # TODO SNOW-1620439: Reuse code from Series.mean. - frame = self._query_compiler._modin_frame - index_id = frame.index_column_snowflake_quoted_identifiers[0] - new_index_id = frame.ordered_dataframe.generate_snowflake_quoted_identifiers( - pandas_labels=["mean"] - )[0] - agg_column_op_params = AggregateColumnOpParameters( - index_id, LongType(), "mean", new_index_id, fn.mean, [] + pandas_dataframe_result = ( + # reset_index(drop=False) copies the index column of + # self._query_compiler into a new data column. Use `drop=False` + # so that we don't have to use SQL row_number() to generate a new + # index column. + self._query_compiler.reset_index(drop=False) + # Aggregate the data column. + .agg("mean", axis=0, args=(), kwargs={"skipna": skipna}) + # convert the query compiler to a pandas dataframe with + # dimensions 1x1 (note that the frame has a single row even + # if `self` is empty.) + .to_pandas() ) - mean_value = aggregate_with_ordered_dataframe( - frame.ordered_dataframe, [agg_column_op_params], {"skipna": skipna} - ).collect()[0][0] - return native_pd.Timedelta(np.nan if mean_value is None else int(mean_value)) + assert pandas_dataframe_result.shape == ( + 1, + 1, + ), "Internal error: aggregation result is not 1x1." + # Return the only element in the frame. 
+ return pandas_dataframe_result.iloc[0, 0] @timedelta_index_not_implemented() def as_unit(self, unit: str) -> TimedeltaIndex: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 8da0794f139..8ffd4081473 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -582,9 +582,6 @@ def __init__( self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) - if self._auto_clean_up_temp_table_enabled: - self._temp_table_auto_cleaner.start() - _logger.info("Snowpark Session information: %s", self._session_info) def __enter__(self): @@ -623,8 +620,8 @@ def close(self) -> None: raise SnowparkClientExceptionMessages.SERVER_FAILED_CLOSE_SESSION(str(ex)) finally: try: - self._conn.close() self._temp_table_auto_cleaner.stop() + self._conn.close() _logger.info("Closed session: %s", self._session_id) finally: _remove_session(self) @@ -658,10 +655,33 @@ def auto_clean_up_temp_table_enabled(self) -> bool: :meth:`DataFrame.cache_result` in the current session when the DataFrame is no longer referenced (i.e., gets garbage collected). The default value is ``False``. + Example:: + + >>> import gc + >>> + >>> def f(session: Session) -> str: + ... df = session.create_dataframe( + ... [[1, 2], [3, 4]], schema=["a", "b"] + ... ).cache_result() + ... return df.table_name + ... + >>> session.auto_clean_up_temp_table_enabled = True + >>> table_name = f(session) + >>> assert table_name + >>> gc.collect() # doctest: +SKIP + >>> + >>> # The temporary table created by cache_result will be dropped when the DataFrame is no longer referenced + >>> # outside the function + >>> session.sql(f"show tables like '{table_name}'").count() + 0 + + >>> session.auto_clean_up_temp_table_enabled = False + Note: - Even if this parameter is ``False``, Snowpark still records temporary tables when - their corresponding DataFrame are garbage collected. Therefore, if you turn it on in the middle of your session or after turning it off, - the target temporary tables will still be cleaned up accordingly. + Temporary tables will only be dropped if this parameter is enabled during garbage collection. + If a temporary table is no longer referenced when the parameter is on, it will be dropped during garbage collection. + However, if garbage collection occurs while the parameter is off, the table will not be removed. + Note that Python's garbage collection is triggered opportunistically, with no guaranteed timing. """ return self._auto_clean_up_temp_table_enabled @@ -755,11 +775,6 @@ def auto_clean_up_temp_table_enabled(self, value: bool) -> None: self._session_id, value ) self._auto_clean_up_temp_table_enabled = value - is_alive = self._temp_table_auto_cleaner.is_alive() - if value and not is_alive: - self._temp_table_auto_cleaner.start() - elif not value and is_alive: - self._temp_table_auto_cleaner.stop() else: raise ValueError( "value for auto_clean_up_temp_table_enabled must be True or False!" 
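# --- Reviewer note: standalone sketch, not part of this patch. --------------------------
# This illustrates the design the cleaner switches to above: the background thread and
# queue are removed, and the DROP now happens inside a weakref.finalize callback when the
# cached table object is garbage collected (guarded by auto_clean_up_temp_table_enabled).
# _TempTable and _CleanerSketch are hypothetical stand-ins for SnowflakeTable and
# TempTableAutoCleaner; the real cleaner issues an async "drop table if exists" query.
import weakref
from collections import defaultdict


class _TempTable:
    def __init__(self, name: str) -> None:
        self.name = name


class _CleanerSketch:
    def __init__(self) -> None:
        self.ref_count_map: dict[str, int] = defaultdict(int)

    def add(self, table: _TempTable) -> None:
        self.ref_count_map[table.name] += 1
        # Runs _delete_ref_count when `table` is garbage collected.
        weakref.finalize(table, self._delete_ref_count, table.name)

    def _delete_ref_count(self, name: str) -> None:
        self.ref_count_map[name] -= 1
        if self.ref_count_map[name] == 0:
            print(f"drop table if exists {name}")  # stand-in for the async DROP query


cleaner = _CleanerSketch()
table = _TempTable("SNOWPARK_TEMP_TABLE_ABC123")
cleaner.add(table)
del table  # on CPython the finalizer fires here; other runtimes may defer it to GC
# -----------------------------------------------------------------------------------------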
diff --git a/tests/integ/modin/conftest.py b/tests/integ/modin/conftest.py index 2f24954e769..a7217b38a50 100644 --- a/tests/integ/modin/conftest.py +++ b/tests/integ/modin/conftest.py @@ -715,3 +715,30 @@ def numeric_test_data_4x4(): "C": [7, 10, 13, 16], "D": [8, 11, 14, 17], } + + +@pytest.fixture +def timedelta_native_df() -> pandas.DataFrame: + return pandas.DataFrame( + { + "A": [ + pd.Timedelta(days=1), + pd.Timedelta(days=2), + pd.Timedelta(days=3), + pd.Timedelta(days=4), + ], + "B": [ + pd.Timedelta(minutes=-1), + pd.Timedelta(minutes=0), + pd.Timedelta(minutes=5), + pd.Timedelta(minutes=6), + ], + "C": [ + None, + pd.Timedelta(nanoseconds=5), + pd.Timedelta(nanoseconds=0), + pd.Timedelta(nanoseconds=4), + ], + "D": pandas.to_timedelta([pd.NaT] * 4), + } + ) diff --git a/tests/integ/modin/frame/test_aggregate.py b/tests/integ/modin/frame/test_aggregate.py index b018682b6f8..ba68ae13734 100644 --- a/tests/integ/modin/frame/test_aggregate.py +++ b/tests/integ/modin/frame/test_aggregate.py @@ -187,6 +187,108 @@ def test_string_sum_with_nulls(): assert_series_equal(snow_result.to_pandas(), native_pd.Series(["ab"])) +class TestTimedelta: + """Test aggregating dataframes containing timedelta columns.""" + + @pytest.mark.parametrize( + "func, union_count", + [ + param( + lambda df: df.aggregate(["min"]), + 0, + id="aggregate_list_with_one_element", + ), + param(lambda df: df.aggregate(x=("A", "max")), 0, id="single_named_agg"), + # this works since all results are timedelta and we don't need to do any concats. + param( + lambda df: df.aggregate({"B": "mean", "A": "sum"}), + 0, + id="dict_producing_two_timedeltas", + ), + # this works since even though we need to do concats, all the results are non-timdelta. + param( + lambda df: df.aggregate(x=("B", "all"), y=("B", "any")), + 1, + id="named_agg_producing_two_bools", + ), + # note following aggregation requires transpose + param(lambda df: df.aggregate(max), 0, id="aggregate_max"), + param(lambda df: df.min(), 0, id="min"), + param(lambda df: df.max(), 0, id="max"), + param(lambda df: df.count(), 0, id="count"), + param(lambda df: df.sum(), 0, id="sum"), + param(lambda df: df.mean(), 0, id="mean"), + param(lambda df: df.median(), 0, id="median"), + param(lambda df: df.std(), 0, id="std"), + param(lambda df: df.quantile(), 0, id="single_quantile"), + param(lambda df: df.quantile([0.01, 0.99]), 1, id="two_quantiles"), + ], + ) + def test_supported_axis_0(self, func, union_count, timedelta_native_df): + with SqlCounter(query_count=1, union_count=union_count): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + func, + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1653126") + def test_axis_1(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), lambda df: df.sum(axis=1) + ) + + @sql_count_checker(query_count=0) + def test_var_invalid(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + lambda df: df.var(), + expect_exception=True, + expect_exception_type=TypeError, + assert_exception_equal=False, + expect_exception_match=re.escape( + "timedelta64 type does not support var operations" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", + ) + @pytest.mark.parametrize( + "operation", + [ + lambda df: df.aggregate({"A": ["count", "max"], 
"B": [max, "min"]}), + lambda df: df.aggregate({"B": ["count"], "A": "sum", "C": ["max", "min"]}), + lambda df: df.aggregate( + x=pd.NamedAgg("A", "max"), y=("B", "min"), c=("A", "count") + ), + lambda df: df.aggregate(["min", np.max]), + lambda df: df.aggregate(x=("A", "max"), y=("C", "min"), z=("A", "min")), + lambda df: df.aggregate(x=("A", "max"), y=pd.NamedAgg("A", "max")), + lambda df: df.aggregate( + {"B": ["idxmax"], "A": "sum", "C": ["max", "idxmin"]} + ), + ], + ) + def test_agg_requires_concat_with_timedelta(self, timedelta_native_df, operation): + eval_snowpark_pandas_result(*create_test_dfs(timedelta_native_df), operation) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires transposing a one-row frame with integer and timedelta.", + ) + def test_agg_produces_timedelta_and_non_timedelta_type(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs(timedelta_native_df), + lambda df: df.aggregate({"B": "idxmax", "A": "sum"}), + ) + + @pytest.mark.parametrize( "func, expected_union_count", [ diff --git a/tests/integ/modin/frame/test_describe.py b/tests/integ/modin/frame/test_describe.py index a9668c5794f..4f1882d441d 100644 --- a/tests/integ/modin/frame/test_describe.py +++ b/tests/integ/modin/frame/test_describe.py @@ -358,3 +358,18 @@ def test_describe_object_file(resources_path): df = pd.read_csv(test_files.test_concat_file1_csv) native_df = df.to_pandas() eval_snowpark_pandas_result(df, native_df, lambda x: x.describe(include="O")) + + +@sql_count_checker(query_count=0) +@pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", +) +def test_timedelta(timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs( + timedelta_native_df, + ), + lambda df: df.describe(), + ) diff --git a/tests/integ/modin/frame/test_idxmax_idxmin.py b/tests/integ/modin/frame/test_idxmax_idxmin.py index 72fe88968bc..87041060bd2 100644 --- a/tests/integ/modin/frame/test_idxmax_idxmin.py +++ b/tests/integ/modin/frame/test_idxmax_idxmin.py @@ -196,8 +196,18 @@ def test_idxmax_idxmin_with_dates(func, axis): @sql_count_checker(query_count=1) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) -@pytest.mark.parametrize("axis", [0, 1]) -@pytest.mark.xfail(reason="SNOW-1625380 TODO") +@pytest.mark.parametrize( + "axis", + [ + 0, + pytest.param( + 1, + marks=pytest.mark.xfail( + strict=True, raises=NotImplementedError, reason="SNOW-1653126" + ), + ), + ], +) def test_idxmax_idxmin_with_timedelta(func, axis): native_df = native_pd.DataFrame( data={ diff --git a/tests/integ/modin/frame/test_nunique.py b/tests/integ/modin/frame/test_nunique.py index d0cad8ec2ad..78098d34386 100644 --- a/tests/integ/modin/frame/test_nunique.py +++ b/tests/integ/modin/frame/test_nunique.py @@ -11,8 +11,13 @@ from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result -TEST_LABELS = np.array(["A", "B", "C", "D"]) -TEST_DATA = [[0, 1, 2, 3], [0, 0, 0, 0], [None, 0, None, 0], [None, None, None, None]] +TEST_LABELS = np.array(["A", "B", "C", "D", "E"]) +TEST_DATA = [ + [0, 1, 2, 3, pd.Timedelta(4)], + [0, 0, 0, 0, pd.Timedelta(0)], + [None, 0, None, 0, pd.Timedelta(0)], + [None, None, None, None, None], +] # which original dataframe (constructed from slicing) to test for TEST_SLICES = [ @@ -80,7 +85,7 @@ def test_dataframe_nunique_no_columns(native_df): [ 
pytest.param(None, id="default_columns"), pytest.param( - [["bar", "bar", "baz", "foo"], ["one", "two", "one", "two"]], + [["bar", "bar", "baz", "foo", "foo"], ["one", "two", "one", "two", "one"]], id="2D_columns", ), ], diff --git a/tests/integ/modin/frame/test_skew.py b/tests/integ/modin/frame/test_skew.py index 72fad6cebdc..94b7fd79c24 100644 --- a/tests/integ/modin/frame/test_skew.py +++ b/tests/integ/modin/frame/test_skew.py @@ -8,7 +8,11 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import assert_series_equal +from tests.integ.modin.utils import ( + assert_series_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) @sql_count_checker(query_count=1) @@ -62,16 +66,22 @@ def test_skew_basic(): }, "kwargs": {"numeric_only": True, "skipna": True}, }, + { + "frame": { + "A": [pd.Timedelta(1)], + }, + "kwargs": { + "numeric_only": True, + }, + }, ], ) @sql_count_checker(query_count=1) def test_skew(data): - native_df = native_pd.DataFrame(data["frame"]) - snow_df = pd.DataFrame(native_df) - assert_series_equal( - snow_df.skew(**data["kwargs"]), - native_df.skew(**data["kwargs"]), - rtol=1.0e-5, + eval_snowpark_pandas_result( + *create_test_dfs(data["frame"]), + lambda df: df.skew(**data["kwargs"]), + rtol=1.0e-5 ) @@ -103,6 +113,14 @@ def test_skew(data): }, "kwargs": {"level": 2}, }, + { + "frame": { + "A": [pd.Timedelta(1)], + }, + "kwargs": { + "numeric_only": False, + }, + }, ], ) @sql_count_checker(query_count=0) diff --git a/tests/integ/modin/groupby/test_all_any.py b/tests/integ/modin/groupby/test_all_any.py index d5234dfbdb5..df8df44d47c 100644 --- a/tests/integ/modin/groupby/test_all_any.py +++ b/tests/integ/modin/groupby/test_all_any.py @@ -14,7 +14,11 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from snowflake.snowpark.exceptions import SnowparkSQLException from tests.integ.modin.sql_counter import sql_count_checker -from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + assert_frame_equal, + create_test_dfs, + eval_snowpark_pandas_result, +) @pytest.mark.parametrize( @@ -109,3 +113,27 @@ def test_all_any_chained(): lambda df: df.apply(lambda ser: ser.str.len()) ) ) + + +@sql_count_checker(query_count=1) +def test_timedelta_any_with_nulls(): + """ + Test this case separately because pandas behavior is different from Snowpark pandas behavior. 
+ + pandas bug that does not apply to Snowpark pandas: + https://github.com/pandas-dev/pandas/issues/59712 + """ + snow_df, native_df = create_test_dfs( + { + "key": ["a"], + "A": native_pd.Series([pd.NaT], dtype="timedelta64[ns]"), + }, + ) + assert_frame_equal( + native_df.groupby("key").any(), + native_pd.DataFrame({"A": [True]}, index=native_pd.Index(["a"], name="key")), + ) + assert_frame_equal( + snow_df.groupby("key").any(), + native_pd.DataFrame({"A": [False]}, index=native_pd.Index(["a"], name="key")), + ) diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index 09acd49bb21..cbf5b75d48c 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -1096,60 +1096,81 @@ def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df): ) -@pytest.mark.parametrize( - "agg_func", - [ - "count", - "sum", - "mean", - "median", - "std", - ], -) -@pytest.mark.parametrize("by", ["A", "B"]) -@sql_count_checker(query_count=1) -def test_timedelta(agg_func, by): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ) - snow_df = pd.DataFrame(native_df) - - eval_snowpark_pandas_result( - snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)() - ) - - -def test_timedelta_groupby_agg(): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "16us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - "C": [True, False, False, True], - } +class TestTimedelta: + @sql_count_checker(query_count=1) + @pytest.mark.parametrize( + "method", + [ + "count", + "mean", + "min", + "max", + "idxmax", + "idxmin", + "sum", + "median", + "std", + "nunique", + ], ) - snow_df = pd.DataFrame(native_df) - with SqlCounter(query_count=1): + @pytest.mark.parametrize("by", ["A", "B"]) + def test_aggregation_methods(self, method, by): eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: getattr(df.groupby(by), method)(), ) - with SqlCounter(query_count=1): - eval_snowpark_pandas_result( - snow_df, - native_df, + + @sql_count_checker(query_count=1) + @pytest.mark.parametrize( + "operation", + [ + lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), lambda df: df.groupby("B").agg({"A": ["sum", "median"], "C": "min"}), + lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}), + lambda df: df.groupby("B").agg(["mean", "std"]), + lambda df: df.groupby("B").agg({"A": ["count", np.sum]}), + lambda df: df.groupby("B").agg({"A": "sum"}), + ], + ) + def test_agg(self, operation): + eval_snowpark_pandas_result( + *create_test_dfs( + native_pd.DataFrame( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + "C": [True, False, False, True], + } + ) + ), + operation, ) - with SqlCounter(query_count=1): + + @sql_count_checker(query_count=1) + def test_groupby_timedelta_var(self): + """ + Test that we can group by a timedelta column and take var() of an integer column. + + Note that we can't take the groupby().var() of the timedelta column because + var() is not defined for timedelta, in pandas or in Snowpark pandas. 
+ """ eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: df.groupby("B").agg({"A": ["sum", "count"], "C": "median"}), + *create_test_dfs( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ), + lambda df: df.groupby("A").var(), ) diff --git a/tests/integ/modin/groupby/test_groupby_first_last.py b/tests/integ/modin/groupby/test_groupby_first_last.py index 5da35806dd1..5e04d5a6fc2 100644 --- a/tests/integ/modin/groupby/test_groupby_first_last.py +++ b/tests/integ/modin/groupby/test_groupby_first_last.py @@ -46,6 +46,17 @@ [np.nan], ] ), + "col11_timedelta": [ + pd.Timedelta("1 days"), + None, + pd.Timedelta("2 days"), + None, + None, + None, + None, + None, + None, + ], } diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py index a009e1089b0..0c9c056c2a7 100644 --- a/tests/integ/modin/groupby/test_groupby_negative.py +++ b/tests/integ/modin/groupby/test_groupby_negative.py @@ -18,6 +18,7 @@ MAP_DATA_AND_TYPE, MIXED_NUMERIC_STR_DATA_AND_TYPE, TIMESTAMP_DATA_AND_TYPE, + create_test_dfs, eval_snowpark_pandas_result, ) @@ -559,20 +560,12 @@ def test_groupby_agg_invalid_min_count( @sql_count_checker(query_count=0) -def test_groupby_var_no_support_for_timedelta(): - native_df = native_pd.DataFrame( - { - "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] - ), - "B": [8, 8, 12, 10], - } - ) - snow_df = pd.DataFrame(native_df) - with pytest.raises( - NotImplementedError, - match=re.escape( - "SnowflakeQueryCompiler::groupby_agg is not yet implemented for Timedelta Type" +def test_timedelta_var_invalid(): + eval_snowpark_pandas_result( + *create_test_dfs( + [["key0", pd.Timedelta(1)]], ), - ): - snow_df.groupby("B").var() + lambda df: df.groupby(0).var(), + expect_exception=True, + expect_exception_type=TypeError, + ) diff --git a/tests/integ/modin/groupby/test_quantile.py b/tests/integ/modin/groupby/test_quantile.py index b14299fee63..940d366a7e2 100644 --- a/tests/integ/modin/groupby/test_quantile.py +++ b/tests/integ/modin/groupby/test_quantile.py @@ -64,6 +64,14 @@ # ), # All NA ([np.nan] * 5, [np.nan] * 5), + pytest.param( + pd.timedelta_range( + "1 days", + "5 days", + ), + pd.timedelta_range("1 second", "5 second"), + id="timedelta", + ), ], ) @pytest.mark.parametrize("q", [0, 0.5, 1]) diff --git a/tests/integ/modin/index/conftest.py b/tests/integ/modin/index/conftest.py index 84454fc4a27..26afd232c4f 100644 --- a/tests/integ/modin/index/conftest.py +++ b/tests/integ/modin/index/conftest.py @@ -79,4 +79,5 @@ tz="America/Los_Angeles", ), native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]), + native_pd.TimedeltaIndex(["4 days", None, "-1 days", "5 days"]), ] diff --git a/tests/integ/modin/index/test_all_any.py b/tests/integ/modin/index/test_all_any.py index 267e7929ea1..499be6f03dc 100644 --- a/tests/integ/modin/index/test_all_any.py +++ b/tests/integ/modin/index/test_all_any.py @@ -25,6 +25,9 @@ native_pd.Index(["a", "b", "c", "d"]), native_pd.Index([5, None, 7]), native_pd.Index([], dtype="object"), + native_pd.Index([pd.Timedelta(0), None]), + native_pd.Index([pd.Timedelta(0)]), + native_pd.Index([pd.Timedelta(0), pd.Timedelta(1)]), ] NATIVE_INDEX_EMPTY_DATA = [ diff --git a/tests/integ/modin/index/test_argmax_argmin.py b/tests/integ/modin/index/test_argmax_argmin.py index 6d446a0a66a..7d42f3b88c9 100644 --- a/tests/integ/modin/index/test_argmax_argmin.py +++ 
b/tests/integ/modin/index/test_argmax_argmin.py
@@ -18,6 +18,18 @@
         native_pd.Index([4, None, 1, 3, 4, 1]),
         native_pd.Index([4, None, 1, 3, 4, 1], name="some name"),
         native_pd.Index([1, 10, 4, 3, 4]),
+        pytest.param(
+            native_pd.Index(
+                [
+                    pd.Timedelta(1),
+                    pd.Timedelta(10),
+                    pd.Timedelta(4),
+                    pd.Timedelta(3),
+                    pd.Timedelta(4),
+                ]
+            ),
+            id="timedelta",
+        ),
     ],
 )
 @pytest.mark.parametrize("func", ["argmax", "argmin"])
diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py
index a01e740ee84..094ffd1280c 100644
--- a/tests/integ/modin/index/test_datetime_index_methods.py
+++ b/tests/integ/modin/index/test_datetime_index_methods.py
@@ -7,6 +7,7 @@
 import numpy as np
 import pandas as native_pd
 import pytest
+import pytz

 import snowflake.snowpark.modin.plugin  # noqa: F401
 from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
@@ -17,6 +18,46 @@
     eval_snowpark_pandas_result,
 )

+timezones = pytest.mark.parametrize(
+    "tz",
+    [
+        None,
+        # Use a subset of pytz.common_timezones containing a few timezones in each
+        *[
+            param_for_one_tz
+            for tz in [
+                "Africa/Abidjan",
+                "Africa/Timbuktu",
+                "America/Adak",
+                "America/Yellowknife",
+                "Antarctica/Casey",
+                "Asia/Dhaka",
+                "Asia/Manila",
+                "Asia/Shanghai",
+                "Atlantic/Stanley",
+                "Australia/Sydney",
+                "Canada/Pacific",
+                "Europe/Chisinau",
+                "Europe/Luxembourg",
+                "Indian/Christmas",
+                "Pacific/Chatham",
+                "Pacific/Wake",
+                "US/Arizona",
+                "US/Central",
+                "US/Eastern",
+                "US/Hawaii",
+                "US/Mountain",
+                "US/Pacific",
+                "UTC",
+            ]
+            for param_for_one_tz in (
+                pytz.timezone(tz),
+                tz,
+            )
+        ],
+    ],
+)
+

 @sql_count_checker(query_count=0)
 def test_datetime_index_construction():
@@ -233,6 +274,76 @@ def test_normalize():
     )


+@sql_count_checker(query_count=1, join_count=1)
+@timezones
+def test_tz_convert(tz):
+    native_index = native_pd.date_range(
+        start="2021-01-01", periods=5, freq="7h", tz="US/Eastern"
+    )
+    native_index = native_index.append(
+        native_pd.DatetimeIndex([pd.NaT], tz="US/Eastern")
+    )
+    snow_index = pd.DatetimeIndex(native_index)
+
+    # Using eval_snowpark_pandas_result() was not possible because currently
+    # Snowpark pandas DatetimeIndex only maintains a timezone-naive dtype
+    # even if the data contains a timezone.
+    assert snow_index.tz_convert(tz).equals(
+        pd.DatetimeIndex(native_index.tz_convert(tz))
+    )
+
+
+@sql_count_checker(query_count=1, join_count=1)
+@timezones
+def test_tz_localize(tz):
+    native_index = native_pd.DatetimeIndex(
+        [
+            "2014-04-04 23:56:01.000000001",
+            "2014-07-18 21:24:02.000000002",
+            "2015-11-22 22:14:03.000000003",
+            "2015-11-23 20:12:04.1234567890",
+            pd.NaT,
+        ],
+    )
+    snow_index = pd.DatetimeIndex(native_index)
+
+    # Using eval_snowpark_pandas_result() was not possible because currently
+    # Snowpark pandas DatetimeIndex only maintains a timezone-naive dtype
+    # even if the data contains a timezone.
+    assert snow_index.tz_localize(tz).equals(
+        pd.DatetimeIndex(native_index.tz_localize(tz))
+    )
+
+
+@pytest.mark.parametrize(
+    "ambiguous, nonexistent",
+    [
+        ("infer", "raise"),
+        ("NaT", "raise"),
+        (np.array([True, True, False]), "raise"),
+        ("raise", "shift_forward"),
+        ("raise", "shift_backward"),
+        ("raise", "NaT"),
+        ("raise", pd.Timedelta("1h")),
+        ("infer", "shift_forward"),
+    ],
+)
+@sql_count_checker(query_count=0)
+def test_tz_localize_negative(ambiguous, nonexistent):
+    native_index = native_pd.DatetimeIndex(
+        [
+            "2014-04-04 23:56:01.000000001",
+            "2014-07-18 21:24:02.000000002",
+            "2015-11-22 22:14:03.000000003",
+            "2015-11-23 20:12:04.1234567890",
+            pd.NaT,
+        ],
+    )
+    snow_index = pd.DatetimeIndex(native_index)
+    with pytest.raises(NotImplementedError):
+        snow_index.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent)
+
+
 @pytest.mark.parametrize(
     "datetime_index_value",
     [
diff --git a/tests/integ/modin/series/test_aggregate.py b/tests/integ/modin/series/test_aggregate.py
index fa354fda1fc..c3e40828d94 100644
--- a/tests/integ/modin/series/test_aggregate.py
+++ b/tests/integ/modin/series/test_aggregate.py
@@ -1,6 +1,8 @@
 #
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #
+import re
+
 import modin.pandas as pd
 import numpy as np
 import pandas as native_pd
@@ -17,6 +19,7 @@
     MAP_DATA_AND_TYPE,
     MIXED_NUMERIC_STR_DATA_AND_TYPE,
     TIMESTAMP_DATA_AND_TYPE,
+    assert_snowpark_pandas_equals_to_pandas_without_dtypecheck,
     create_test_series,
     eval_snowpark_pandas_result,
 )
@@ -358,3 +361,67 @@ def test_2_tuple_named_agg_errors_for_series(native_series, agg_kwargs):
         expect_exception_type=SpecificationError,
         assert_exception_equal=True,
     )
+
+
+class TestTimedelta:
+    """Test aggregating a timedelta series."""
+
+    @pytest.mark.parametrize(
+        "func, union_count, is_scalar",
+        [
+            pytest.param(*v, id=str(i))
+            for i, v in enumerate(
+                [
+                    (lambda series: series.aggregate(["min"]), 0, False),
+                    (lambda series: series.aggregate({"A": "max"}), 0, False),
+                    # this works since even though we need to do concats, all the results are non-timedelta.
+ (lambda df: df.aggregate(["all", "any", "count"]), 2, False), + # note following aggregation requires transpose + (lambda df: df.aggregate(max), 0, True), + (lambda df: df.min(), 0, True), + (lambda df: df.max(), 0, True), + (lambda df: df.count(), 0, True), + (lambda df: df.sum(), 0, True), + (lambda df: df.mean(), 0, True), + (lambda df: df.median(), 0, True), + (lambda df: df.std(), 0, True), + (lambda df: df.quantile(), 0, True), + (lambda df: df.quantile([0.01, 0.99]), 0, False), + ] + ) + ], + ) + def test_supported(self, func, union_count, timedelta_native_df, is_scalar): + with SqlCounter(query_count=1, union_count=union_count): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + func, + comparator=validate_scalar_result + if is_scalar + else assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, + ) + + @sql_count_checker(query_count=0) + def test_var_invalid(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + lambda series: series.var(), + expect_exception=True, + expect_exception_type=TypeError, + assert_exception_equal=False, + expect_exception_match=re.escape( + "timedelta64 type does not support var operations" + ), + ) + + @sql_count_checker(query_count=0) + @pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", + ) + def test_unsupported_due_to_concat(self, timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_series(timedelta_native_df["A"]), + lambda df: df.agg(["count", "max"]), + ) diff --git a/tests/integ/modin/series/test_argmax_argmin.py b/tests/integ/modin/series/test_argmax_argmin.py index 607b36a27f3..e212e3ba2dd 100644 --- a/tests/integ/modin/series/test_argmax_argmin.py +++ b/tests/integ/modin/series/test_argmax_argmin.py @@ -18,6 +18,11 @@ ([4, None, 1, 3, 4, 1], ["A", "B", "C", "D", "E", "F"]), ([4, None, 1, 3, 4, 1], [None, "B", "C", "D", "E", "F"]), ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + pytest.param( + [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)], + ["A", "B", "C", "D", "E"], + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["argmax", "argmin"]) diff --git a/tests/integ/modin/series/test_describe.py b/tests/integ/modin/series/test_describe.py index 9ecd2e33a3d..0f7bbda6c3a 100644 --- a/tests/integ/modin/series/test_describe.py +++ b/tests/integ/modin/series/test_describe.py @@ -11,6 +11,7 @@ from tests.integ.modin.sql_counter import sql_count_checker from tests.integ.modin.utils import ( assert_series_equal, + create_test_dfs, create_test_series, eval_snowpark_pandas_result, ) @@ -156,3 +157,18 @@ def test_describe_multiindex(data, index): eval_snowpark_pandas_result( *create_test_series(data, index=index), lambda ser: ser.describe() ) + + +@sql_count_checker(query_count=0) +@pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="requires concat(), which we cannot do with Timedelta.", +) +def test_timedelta(timedelta_native_df): + eval_snowpark_pandas_result( + *create_test_dfs( + timedelta_native_df, + ), + lambda df: df["A"].describe(), + ) diff --git a/tests/integ/modin/series/test_first_last_valid_index.py b/tests/integ/modin/series/test_first_last_valid_index.py index 1e8d052e10f..1930bdf1088 100644 --- a/tests/integ/modin/series/test_first_last_valid_index.py +++ b/tests/integ/modin/series/test_first_last_valid_index.py @@ -22,6 +22,10 @@ native_pd.Series([5, 6, 7, 8], index=["i", "am", "iron", 
"man"]), native_pd.Series([None, None, 2], index=[None, 1, 2]), native_pd.Series([None, None, 2], index=[None, None, None]), + pytest.param( + native_pd.Series([None, None, pd.Timedelta(2)], index=[None, 1, 2]), + id="timedelta", + ), ], ) def test_first_and_last_valid_index_series(native_series): diff --git a/tests/integ/modin/series/test_idxmax_idxmin.py b/tests/integ/modin/series/test_idxmax_idxmin.py index ea536240a42..e8e66a30f61 100644 --- a/tests/integ/modin/series/test_idxmax_idxmin.py +++ b/tests/integ/modin/series/test_idxmax_idxmin.py @@ -17,6 +17,11 @@ ([1, None, 4, 3, 4], ["A", "B", "C", "D", "E"]), ([1, None, 4, 3, 4], [None, "B", "C", "D", "E"]), ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + pytest.param( + [pd.Timedelta(1), None, pd.Timedelta(4), pd.Timedelta(3), pd.Timedelta(4)], + ["A", "B", "C", "D", "E"], + id="timedelta", + ), ], ) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) diff --git a/tests/integ/modin/series/test_nunique.py b/tests/integ/modin/series/test_nunique.py index bb20e9e4a53..3856dbc516a 100644 --- a/tests/integ/modin/series/test_nunique.py +++ b/tests/integ/modin/series/test_nunique.py @@ -6,6 +6,7 @@ import numpy as np import pandas as native_pd import pytest +from pytest import param import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker @@ -32,6 +33,20 @@ [True, None, False, True, None], [1.1, "a", None] * 4, [native_pd.to_datetime("2023-12-01"), native_pd.to_datetime("1999-09-09")] * 2, + param( + [ + native_pd.Timedelta(1), + native_pd.Timedelta(1), + native_pd.Timedelta(2), + None, + None, + ], + id="timedelta_with_nulls", + ), + param( + [native_pd.Timedelta(1), native_pd.Timedelta(1), native_pd.Timedelta(2)], + id="timedelta_without_nulls", + ), ], ) @pytest.mark.parametrize("dropna", [True, False]) diff --git a/tests/integ/test_query_plan_analysis.py b/tests/integ/test_query_plan_analysis.py index 0e8bb0d902d..81b852c46c1 100644 --- a/tests/integ/test_query_plan_analysis.py +++ b/tests/integ/test_query_plan_analysis.py @@ -98,6 +98,24 @@ def test_range_statement(session: Session): ) +def test_literal_complexity_for_snowflake_values(session: Session): + from snowflake.snowpark._internal.analyzer import analyzer + + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + assert_df_subtree_query_complexity( + df1, {PlanNodeCategory.COLUMN: 4, PlanNodeCategory.LITERAL: 4} + ) + + try: + original_threshold = analyzer.ARRAY_BIND_THRESHOLD + analyzer.ARRAY_BIND_THRESHOLD = 2 + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + # SELECT "A", "B" from (SELECT * FROM TEMP_TABLE) + assert_df_subtree_query_complexity(df2, {PlanNodeCategory.COLUMN: 3}) + finally: + analyzer.ARRAY_BIND_THRESHOLD = original_threshold + + def test_generator_table_function(session: Session): df1 = session.generator( seq1(1).as_("seq"), uniform(1, 10, 2).as_("uniform"), rowcount=150 diff --git a/tests/integ/test_telemetry.py b/tests/integ/test_telemetry.py index 7aaa5c9e5dd..39749de76f6 100644 --- a/tests/integ/test_telemetry.py +++ b/tests/integ/test_telemetry.py @@ -1223,3 +1223,51 @@ def send_telemetry(): data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) assert data == expected_data assert type_ == "snowpark_compilation_stage_statistics" + + +def test_temp_table_cleanup(session): + client = session._conn._telemetry_client + + def send_telemetry(): + client.send_temp_table_cleanup_telemetry( + session.session_id, + 
temp_table_cleaner_enabled=True, + num_temp_tables_cleaned=2, + num_temp_tables_created=5, + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "temp_table_cleaner_enabled": True, + "num_temp_tables_cleaned": 2, + "num_temp_tables_created": 5, + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_temp_table_cleanup" + + +def test_temp_table_cleanup_exception(session): + client = session._conn._telemetry_client + + def send_telemetry(): + client.send_temp_table_cleanup_abnormal_exception_telemetry( + session.session_id, + table_name="table_name_placeholder", + exception_message="exception_message_placeholder", + ) + + telemetry_tracker = TelemetryDataTracker(session) + + expected_data = { + "session_id": session.session_id, + "temp_table_cleanup_abnormal_exception_table_name": "table_name_placeholder", + "temp_table_cleanup_abnormal_exception_message": "exception_message_placeholder", + } + + data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry) + assert data == expected_data + assert type_ == "snowpark_temp_table_cleanup_abnormal_exception" diff --git a/tests/integ/test_temp_table_cleanup.py b/tests/integ/test_temp_table_cleanup.py index 4ac87661484..cdd97d49937 100644 --- a/tests/integ/test_temp_table_cleanup.py +++ b/tests/integ/test_temp_table_cleanup.py @@ -12,6 +12,7 @@ from snowflake.snowpark._internal.utils import ( TempObjectType, random_name_for_temp_object, + warning_dict, ) from snowflake.snowpark.functions import col from tests.utils import IS_IN_STORED_PROC @@ -25,40 +26,61 @@ WAIT_TIME = 1 +@pytest.fixture(autouse=True) +def setup(session): + auto_clean_up_temp_table_enabled = session.auto_clean_up_temp_table_enabled + session.auto_clean_up_temp_table_enabled = True + yield + session.auto_clean_up_temp_table_enabled = auto_clean_up_temp_table_enabled + + def test_basic(session): + session._temp_table_auto_cleaner.ref_count_map.clear() df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() table_name = df1.table_name table_ids = table_name.split(".") df1.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = df1.select("*").filter(col("a") == 1) df2.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df3 = df1.union_all(df2) df3.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 del df2 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert 
session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 del df3 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids) - assert table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 def test_function(session): + session._temp_table_auto_cleaner.ref_count_map.clear() table_name = None def f(session: Session) -> None: @@ -68,13 +90,16 @@ def f(session: Session) -> None: nonlocal table_name table_name = df.table_name assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() f(session) gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_name.split(".")) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 @pytest.mark.parametrize( @@ -86,33 +111,42 @@ def f(session: Session) -> None: ], ) def test_copy(session, copy_function): + session._temp_table_auto_cleaner.ref_count_map.clear() df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() table_name = df1.table_name table_ids = table_name.split(".") df1.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = copy_function(df1).select("*").filter(col("a") == 1) df2.collect() assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 2 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert session._table_exists(table_ids) assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df2 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids) - assert table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_reference_count_map_multiple_sessions(db_parameters, session): + session._temp_table_auto_cleaner.ref_count_map.clear() new_session = Session.builder.configs(db_parameters).create() + new_session.auto_clean_up_temp_table_enabled = True try: df1 = session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] @@ -120,43 +154,59 @@ def test_reference_count_map_multiple_sessions(db_parameters, session): table_name1 = df1.table_name table_ids1 = table_name1.split(".") assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 1 - assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 + assert 
session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 + assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 0 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 df2 = new_session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] ).cache_result() table_name2 = df2.table_name table_ids2 = table_name2.split(".") - assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - session._temp_table_auto_cleaner.start() del df1 gc.collect() time.sleep(WAIT_TIME) assert not session._table_exists(table_ids1) assert new_session._table_exists(table_ids2) assert session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 - assert new_session._temp_table_auto_cleaner.ref_count_map[table_name1] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 + assert table_name1 not in new_session._temp_table_auto_cleaner.ref_count_map + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 - new_session._temp_table_auto_cleaner.start() del df2 gc.collect() time.sleep(WAIT_TIME) assert not new_session._table_exists(table_ids2) - assert session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert table_name2 not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 assert new_session._temp_table_auto_cleaner.ref_count_map[table_name2] == 0 + assert new_session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert new_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 finally: new_session.close() def test_save_as_table_no_drop(session): - session._temp_table_auto_cleaner.start() + session._temp_table_auto_cleaner.ref_count_map.clear() def f(session: Session, temp_table_name: str) -> None: session.create_dataframe( [[1, 2], [3, 4]], schema=["a", "b"] ).write.save_as_table(temp_table_name, table_type="temp") - assert session._temp_table_auto_cleaner.ref_count_map[temp_table_name] == 0 + assert temp_table_name not in session._temp_table_auto_cleaner.ref_count_map + assert session._temp_table_auto_cleaner.num_temp_tables_created == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 0 temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) f(session, temp_table_name) @@ -165,34 +215,25 @@ def f(session: Session, temp_table_name: str) -> None: assert session._table_exists([temp_table_name]) -def test_start_stop(session): - session._temp_table_auto_cleaner.stop() - - df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() - table_name = df1.table_name +def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog): + warning_dict.clear() + 
with caplog.at_level(logging.WARNING): + session.auto_clean_up_temp_table_enabled = False + assert session.auto_clean_up_temp_table_enabled is False + assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).cache_result() + table_name = df.table_name table_ids = table_name.split(".") - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 1 - del df1 + del df gc.collect() - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 - assert not session._temp_table_auto_cleaner.queue.empty() - assert session._table_exists(table_ids) - - session._temp_table_auto_cleaner.start() time.sleep(WAIT_TIME) - assert session._temp_table_auto_cleaner.queue.empty() - assert not session._table_exists(table_ids) - - -def test_auto_clean_up_temp_table_enabled_parameter(db_parameters, session, caplog): - with caplog.at_level(logging.WARNING): - session.auto_clean_up_temp_table_enabled = True + assert session._table_exists(table_ids) + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 1 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 1 + session.auto_clean_up_temp_table_enabled = True assert session.auto_clean_up_temp_table_enabled is True - assert "auto_clean_up_temp_table_enabled is experimental" in caplog.text - assert session._temp_table_auto_cleaner.is_alive() - session.auto_clean_up_temp_table_enabled = False - assert session.auto_clean_up_temp_table_enabled is False - assert not session._temp_table_auto_cleaner.is_alive() + with pytest.raises( ValueError, match="value for auto_clean_up_temp_table_enabled must be True or False!", diff --git a/tests/unit/modin/test_aggregation_utils.py b/tests/unit/modin/test_aggregation_utils.py index 5434387ba71..6c9edfd024f 100644 --- a/tests/unit/modin/test_aggregation_utils.py +++ b/tests/unit/modin/test_aggregation_utils.py @@ -2,12 +2,20 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# +from types import MappingProxyType +from unittest import mock + import numpy as np import pytest +import snowflake.snowpark.modin.plugin._internal.aggregation_utils as aggregation_utils +from snowflake.snowpark.functions import greatest, sum as sum_ from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + SnowflakeAggFunc, + _is_supported_snowflake_agg_func, + _SnowparkPandasAggregation, check_is_aggregation_supported_in_snowflake, - is_supported_snowflake_agg_func, + get_snowflake_agg_func, ) @@ -53,8 +61,8 @@ ("quantile", {}, 1, False), ], ) -def test_is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None: - assert is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid +def test__is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, is_valid) -> None: + assert _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) is is_valid @pytest.mark.parametrize( @@ -103,3 +111,40 @@ def test_check_aggregation_snowflake_execution_capability_by_args( agg_func=agg_func, agg_kwargs=agg_kwargs, axis=0 ) assert can_be_distributed == expected_result + + +@pytest.mark.parametrize( + "agg_func, agg_kwargs, axis, expected", + [ + (np.sum, {}, 0, SnowflakeAggFunc(sum_, True)), + ("max", {"skipna": False}, 1, SnowflakeAggFunc(greatest, True)), + ("test", {}, 0, None), + ], +) +def test_get_snowflake_agg_func(agg_func, agg_kwargs, axis, expected): + result = get_snowflake_agg_func(agg_func, agg_kwargs, axis) + if expected is None: + assert result is None + else: + assert result == expected + + +def test_get_snowflake_agg_func_with_no_implementation_on_axis_0(): + """Test get_snowflake_agg_func for a function that we support on axis=1 but not on axis=0.""" + # We have to patch the internal dictionary + # _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION here because there is + # no real function that we support on axis=1 but not on axis=0. + with mock.patch.object( + aggregation_utils, + "_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION", + MappingProxyType( + { + "max": _SnowparkPandasAggregation( + preserves_snowpark_pandas_types=True, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=greatest, + ) + } + ), + ): + assert get_snowflake_agg_func(agg_func="max", agg_kwargs={}, axis=0) is None diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 262c9e82c44..370ee455d62 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -112,6 +112,7 @@ def test_used_scoped_temp_object(): def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() + fake_connection._telemetry_client = mock.Mock() fake_connection.is_closed = MagicMock(return_value=False) exception_msg = "Mock exception for session.cancel_all" fake_connection.run_query = MagicMock(side_effect=Exception(exception_msg))
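For reviewers who want to try the new behavior outside the test suite, below is a minimal sketch of the axis=0 Timedelta aggregation path that tests/integ/modin/frame/test_aggregate.py exercises. It is not part of the patch, uses only calls the tests above already cover, and assumes a Snowflake connection is already configured for Snowpark pandas.

# Minimal sketch (assumes an already-configured Snowpark pandas session).
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowflake backend

df = pd.DataFrame(
    {
        "A": [pd.Timedelta(days=1), pd.Timedelta(days=2)],
        "B": [pd.Timedelta(minutes=-1), pd.Timedelta(minutes=5)],
    }
)

# Supported on axis=0: every result is a Timedelta, so no concat of mixed
# result types is needed.
print(df.aggregate({"B": "mean", "A": "sum"}))
print(df.min())

# Per the negative tests above: df.var() raises TypeError for timedelta64, and
# axis=1 aggregations such as df.sum(axis=1) raise NotImplementedError (SNOW-1653126).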