
Commit

Merge branch 'main' into vbudati/SNOW-1559025-datetimeindex-mean-std
sfc-gh-vbudati authored Sep 13, 2024
2 parents c592a0f + 5b5c03b commit 550c4c7
Showing 39 changed files with 1,319 additions and 439 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@
#### New Features

- Added support for `TimedeltaIndex.mean` method.
- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`.
- Added support for `DatetimeIndex.mean` and `DatetimeIndex.std` methods.


@@ -125,6 +126,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det
- Added support for `Series.dt.total_seconds` method.
- Added support for `DataFrame.apply(axis=0)`.
- Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`.
- Added support for `DatetimeIndex.tz_convert` and `DatetimeIndex.tz_localize`.

#### Improvements

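The changelog entries above (and the branch name) cover the new `DatetimeIndex.mean` and `DatetimeIndex.std` support. A minimal usage sketch for Snowpark pandas, assuming a Snowflake connection is already configured; the constructor call and printed values are illustrative, not taken from the commit:

```python
# Hedged usage sketch for the newly supported DatetimeIndex.mean / DatetimeIndex.std.
# Assumes a configured Snowflake connection for Snowpark pandas.
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowpark pandas backend

idx = pd.DatetimeIndex(["2024-01-01", "2024-01-03", "2024-01-05"])
print(idx.mean())  # midpoint timestamp (e.g. 2024-01-03)
print(idx.std())   # spread of the timestamps as a Timedelta
```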
4 changes: 2 additions & 2 deletions docs/source/modin/supported/datetime_index_supported.rst
@@ -82,9 +82,9 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``snap`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``tz_convert`` | N | | |
| ``tz_convert`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``tz_localize`` | N | | |
| ``tz_localize`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``round`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
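Per the table above, `tz_convert` is now fully supported and `tz_localize` is partially supported (the `ambiguous` and `nonexistent` parameters are not). A hedged usage sketch, again assuming a configured Snowpark pandas backend:

```python
# Hedged sketch: localize naive timestamps, then convert between time zones.
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowpark pandas backend

idx = pd.DatetimeIndex(["2024-06-01 09:00", "2024-06-01 17:00"])
localized = idx.tz_localize("UTC")             # ambiguous/nonexistent arguments are unsupported
converted = localized.tz_convert("US/Pacific")
print(converted)
```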
5 changes: 1 addition & 4 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -956,10 +956,7 @@ def do_resolve_with_resolved_children(
schema_query = schema_query_for_values_statement(logical_plan.output)

if logical_plan.data:
if (
len(logical_plan.output) * len(logical_plan.data)
< ARRAY_BIND_THRESHOLD
):
if not logical_plan.is_large_local_data:
return self.plan_builder.query(
values_statement(logical_plan.output, logical_plan.data),
logical_plan,
19 changes: 18 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
@@ -144,10 +144,27 @@ def __init__(
self.data = data
self.schema_query = schema_query

@property
def is_large_local_data(self) -> bool:
from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD

return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
if self.is_large_local_data:
# When the number of literals exceeds the threshold, we generate 3 queries:
# 1. create table query
# 2. insert into table query
# 3. select * from table query
# We only consider the complexity from the final select * query since other queries
# are built based on it.
return {
PlanNodeCategory.COLUMN: 1,
}

# If we stay under the threshold, we generate a single query:
# select $1, ..., $m FROM VALUES (r11, r12, ..., r1m), (rn1, ...., rnm)
# TODO: use ARRAY_BIND_THRESHOLD
return {
PlanNodeCategory.COLUMN: len(self.output),
PlanNodeCategory.LITERAL: len(self.data) * len(self.output),
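The comments in this hunk explain the two code paths: above the threshold, local data is uploaded via create table / insert / select queries and only the final `SELECT *` counts toward complexity; below it, a single `SELECT ... FROM VALUES` is generated and every column and literal counts. The analyzer hunk above now delegates to the same `is_large_local_data` property instead of repeating the arithmetic. A standalone sketch of that threshold logic (not the Snowpark classes themselves; the `ARRAY_BIND_THRESHOLD` value here is assumed purely for illustration):

```python
# Standalone sketch of the threshold-based complexity estimate described above.
# ARRAY_BIND_THRESHOLD and the category names mirror the diff; the value is illustrative.
from enum import Enum
from typing import Any, Dict, List, Tuple

ARRAY_BIND_THRESHOLD = 512  # assumed value, for illustration only


class PlanNodeCategory(Enum):
    COLUMN = "column"
    LITERAL = "literal"


def values_complexity(
    output: List[str], data: List[Tuple[Any, ...]]
) -> Dict[PlanNodeCategory, int]:
    # Large local data takes the create table / insert / select path,
    # so only the final SELECT * contributes to complexity.
    if len(data) * len(output) >= ARRAY_BIND_THRESHOLD:
        return {PlanNodeCategory.COLUMN: 1}
    # Otherwise a single SELECT ... FROM VALUES is generated and every
    # projected column and every literal counts toward complexity.
    return {
        PlanNodeCategory.COLUMN: len(output),
        PlanNodeCategory.LITERAL: len(data) * len(output),
    }


# Example: 3 columns x 2 rows = 6 literals, well under the threshold.
print(values_complexity(["a", "b", "c"], [(1, 2, 3), (4, 5, 6)]))
```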
52 changes: 52 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
@@ -79,6 +79,20 @@ class TelemetryField(Enum):
QUERY_PLAN_HEIGHT = "query_plan_height"
QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes"
QUERY_PLAN_COMPLEXITY = "query_plan_complexity"
# temp table cleanup
TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup"
NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned"
NUM_TEMP_TABLES_CREATED = "num_temp_tables_created"
TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled"
TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = (
"snowpark_temp_table_cleanup_abnormal_exception"
)
TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = (
"temp_table_cleanup_abnormal_exception_table_name"
)
TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = (
"temp_table_cleanup_abnormal_exception_message"
)


# These DataFrame APIs call other DataFrame APIs
@@ -464,3 +478,41 @@ def send_large_query_optimization_skipped_telemetry(
},
}
self.send(message)

def send_temp_table_cleanup_telemetry(
self,
session_id: str,
temp_table_cleaner_enabled: bool,
num_temp_tables_cleaned: int,
num_temp_tables_created: int,
) -> None:
message = {
**self._create_basic_telemetry_data(
TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value
),
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
TelemetryField.TEMP_TABLE_CLEANER_ENABLED.value: temp_table_cleaner_enabled,
TelemetryField.NUM_TEMP_TABLES_CLEANED.value: num_temp_tables_cleaned,
TelemetryField.NUM_TEMP_TABLES_CREATED.value: num_temp_tables_created,
},
}
self.send(message)

def send_temp_table_cleanup_abnormal_exception_telemetry(
self,
session_id: str,
table_name: str,
exception_message: str,
) -> None:
message = {
**self._create_basic_telemetry_data(
TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value
),
TelemetryField.KEY_DATA.value: {
TelemetryField.SESSION_ID.value: session_id,
TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME.value: table_name,
TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE.value: exception_message,
},
}
self.send(message)
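For reference, the two new methods assemble messages of roughly the following shape. The envelope fields contributed by `_create_basic_telemetry_data` and the `"data"` key are assumptions based on the field names in the diff, and all values are illustrative:

```python
# Illustrative payloads only; the exact envelope and values are assumptions.
cleanup_message = {
    "type": "snowpark_temp_table_cleanup",
    "data": {
        "session_id": "1234567890",
        "temp_table_cleaner_enabled": True,
        "num_temp_tables_cleaned": 3,
        "num_temp_tables_created": 5,
    },
}

abnormal_exception_message = {
    "type": "snowpark_temp_table_cleanup_abnormal_exception",
    "data": {
        "session_id": "1234567890",
        "temp_table_cleanup_abnormal_exception_table_name": "SNOWPARK_TEMP_TABLE_ABC",
        "temp_table_cleanup_abnormal_exception_message": "example error message",
    },
}
```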
89 changes: 40 additions & 49 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
@@ -4,9 +4,7 @@
import logging
import weakref
from collections import defaultdict
from queue import Empty, Queue
from threading import Event, Thread
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable

@@ -33,74 +31,67 @@ def __init__(self, session: "Session") -> None:
# to its reference count for later temp table management
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# unused temp table will be put into the queue for cleanup
self.queue: Queue = Queue()
# thread for removing temp tables (running DROP TABLE sql)
self.cleanup_thread: Optional[Thread] = None
# An event managing a flag that indicates whether the cleaner is started
self.stop_event = Event()

def add(self, table: SnowflakeTable) -> None:
self.ref_count_map[table.name] += 1
# the finalizer will be triggered when it gets garbage collected
# and this table will be dropped finally
_ = weakref.finalize(table, self._delete_ref_count, table.name)

def _delete_ref_count(self, name: str) -> None:
def _delete_ref_count(self, name: str) -> None: # pragma: no cover
"""
Decrements the reference count of a temporary table,
and if the count reaches zero, puts this table in the queue for cleanup.
"""
self.ref_count_map[name] -= 1
if self.ref_count_map[name] == 0:
self.ref_count_map.pop(name)
# clean up
self.queue.put(name)
if self.session.auto_clean_up_temp_table_enabled:
self.drop_table(name)
elif self.ref_count_map[name] < 0:
logging.debug(
f"Unexpected reference count {self.ref_count_map[name]} for table {name}"
)

def process_cleanup(self) -> None:
while not self.stop_event.is_set():
try:
# it's non-blocking after timeout and become interruptable with stop_event
# it will raise an `Empty` exception if queue is empty after timeout,
# then we catch this exception and avoid breaking loop
table_name = self.queue.get(timeout=1)
self.drop_table(table_name)
except Empty:
continue

def drop_table(self, name: str) -> None:
def drop_table(self, name: str) -> None: # pragma: no cover
common_log_text = f"temp table {name} in session {self.session.session_id}"
logging.debug(f"Cleanup Thread: Ready to drop {common_log_text}")
logging.debug(f"Ready to drop {common_log_text}")
query_id = None
try:
# TODO SNOW-1556553: Remove this workaround once multi-threading of Snowpark session is supported
with self.session._conn._conn.cursor() as cursor:
cursor.execute(
f"drop table if exists {name} /* internal query to drop unused temp table */",
_statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name},
async_job = self.session.sql(
f"drop table if exists {name} /* internal query to drop unused temp table */",
)._internal_collect_with_tag_no_telemetry(
block=False, statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name}
)
query_id = async_job.query_id
logging.debug(f"Dropping {common_log_text} with query id {query_id}")
except Exception as ex: # pragma: no cover
warning_message = f"Failed to drop {common_log_text}, exception: {ex}"
logging.warning(warning_message)
if query_id is None:
# If no query_id is available, it means the query hasn't been accepted by gs,
# and it won't occur in our job_etl_view, send a separate telemetry for recording.
self.session._conn._telemetry_client.send_temp_table_cleanup_abnormal_exception_telemetry(
self.session.session_id,
name,
str(ex),
)
logging.debug(f"Cleanup Thread: Successfully dropped {common_log_text}")
except Exception as ex:
logging.warning(
f"Cleanup Thread: Failed to drop {common_log_text}, exception: {ex}"
) # pragma: no cover

def is_alive(self) -> bool:
return self.cleanup_thread is not None and self.cleanup_thread.is_alive()

def start(self) -> None:
self.stop_event.clear()
if not self.is_alive():
self.cleanup_thread = Thread(target=self.process_cleanup)
self.cleanup_thread.start()

def stop(self) -> None:
"""
The cleaner will stop immediately and leave unfinished temp tables in the queue.
Stops the cleaner (no-op) and sends the telemetry.
"""
self.stop_event.set()
if self.is_alive():
self.cleanup_thread.join()
self.session._conn._telemetry_client.send_temp_table_cleanup_telemetry(
self.session.session_id,
temp_table_cleaner_enabled=self.session.auto_clean_up_temp_table_enabled,
num_temp_tables_cleaned=self.num_temp_tables_cleaned,
num_temp_tables_created=self.num_temp_tables_created,
)

@property
def num_temp_tables_created(self) -> int:
return len(self.ref_count_map)

@property
def num_temp_tables_cleaned(self) -> int:
# TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled
return sum(v == 0 for v in self.ref_count_map.values())
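The rewrite above removes the background thread and queue: cleanup now happens directly in the `weakref.finalize` callback, which issues an asynchronous `DROP TABLE` once a table's reference count reaches zero (and the session parameter is enabled). A minimal standalone sketch of that pattern, with illustrative names rather than the Snowpark classes:

```python
# Minimal sketch of the weakref.finalize pattern used above: when the last
# reference to a table object is garbage collected, the finalizer decrements
# a shared counter and triggers cleanup.
import weakref
from collections import defaultdict


class FakeTable:
    def __init__(self, name: str) -> None:
        self.name = name


ref_count = defaultdict(int)


def on_collect(name: str) -> None:
    ref_count[name] -= 1
    if ref_count[name] == 0:
        # The real code issues an asynchronous DROP via the session.
        print(f"DROP TABLE IF EXISTS {name}")


def track(table: FakeTable) -> None:
    ref_count[table.name] += 1
    # The finalizer fires when the table object is garbage collected.
    weakref.finalize(table, on_collect, table.name)


t = FakeTable("TEMP_T1")
track(t)
del t  # in CPython the finalizer runs immediately once the refcount hits zero
```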