Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1641644: Drop temp table directly at garbage collection instead of using multi-threading #2214

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/snowflake/snowpark/_internal/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ class TelemetryField(Enum):
QUERY_PLAN_HEIGHT = "query_plan_height"
QUERY_PLAN_NUM_DUPLICATE_NODES = "query_plan_num_duplicate_nodes"
QUERY_PLAN_COMPLEXITY = "query_plan_complexity"
# temp table cleanup
TYPE_TEMP_TABLE_CLEANUP = "snowpark_temp_table_cleanup"
NUM_TEMP_TABLES_CLEANED = "num_temp_tables_cleaned"
NUM_TEMP_TABLES_CREATED = "num_temp_tables_created"
TEMP_TABLE_CLEANER_ENABLED = "temp_table_cleaner_enabled"
TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION = (
"snowpark_temp_table_cleanup_abnormal_exception"
)
TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME = (
"temp_table_cleanup_abnormal_exception_table_name"
)
TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE = (
"temp_table_cleanup_abnormal_exception_message"
)


# These DataFrame APIs call other DataFrame APIs
Expand Down Expand Up @@ -464,3 +478,41 @@ def send_large_query_optimization_skipped_telemetry(
},
}
self.send(message)

def send_temp_table_cleanup_telemetry(
    self,
    session_id: str,
    temp_table_cleaner_enabled: bool,
    num_temp_tables_cleaned: int,
    num_temp_tables_created: int,
) -> None:
    """Send a summary record of temp-table auto-cleanup activity for a session.

    Reports whether the cleaner was enabled together with the counts of
    temp tables created and cleaned, under the
    ``snowpark_temp_table_cleanup`` telemetry type.
    """
    telemetry_type = TelemetryField.TYPE_TEMP_TABLE_CLEANUP.value
    # Assemble the data payload separately, then merge it with the basic
    # telemetry envelope before sending.
    payload = {
        TelemetryField.SESSION_ID.value: session_id,
        TelemetryField.TEMP_TABLE_CLEANER_ENABLED.value: temp_table_cleaner_enabled,
        TelemetryField.NUM_TEMP_TABLES_CLEANED.value: num_temp_tables_cleaned,
        TelemetryField.NUM_TEMP_TABLES_CREATED.value: num_temp_tables_created,
    }
    message = {
        **self._create_basic_telemetry_data(telemetry_type),
        TelemetryField.KEY_DATA.value: payload,
    }
    self.send(message)

def send_temp_table_cleanup_abnormal_exception_telemetry(
    self,
    session_id: str,
    table_name: str,
    exception_message: str,
) -> None:
    """Send a telemetry record for a temp-table drop that failed abnormally.

    Used when the DROP TABLE query never received a query id, so the
    failure would otherwise be invisible in server-side job views.
    """
    telemetry_type = TelemetryField.TYPE_TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION.value
    # Payload carries the table name and exception text for diagnosis.
    payload = {
        TelemetryField.SESSION_ID.value: session_id,
        TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_TABLE_NAME.value: table_name,
        TelemetryField.TEMP_TABLE_CLEANUP_ABNORMAL_EXCEPTION_MESSAGE.value: exception_message,
    }
    message = {
        **self._create_basic_telemetry_data(telemetry_type),
        TelemetryField.KEY_DATA.value: payload,
    }
    self.send(message)
89 changes: 40 additions & 49 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import logging
import weakref
from collections import defaultdict
from queue import Empty, Queue
from threading import Event, Thread
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable

Expand All @@ -33,74 +31,67 @@ def __init__(self, session: "Session") -> None:
# to its reference count for later temp table management
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# unused temp table will be put into the queue for cleanup
self.queue: Queue = Queue()
# thread for removing temp tables (running DROP TABLE sql)
self.cleanup_thread: Optional[Thread] = None
# An event managing a flag that indicates whether the cleaner is started
self.stop_event = Event()

def add(self, table: SnowflakeTable) -> None:
    """Register a temp table for auto-cleanup tracking.

    Increments the table's reference count and attaches a GC finalizer:
    when ``table`` is garbage collected, the finalizer decrements the
    count, which may ultimately drop the underlying table.
    """
    self.ref_count_map[table.name] += 1
    # The finalizer fires when `table` is garbage collected; the return
    # value does not need to be kept for it to stay registered.
    weakref.finalize(table, self._delete_ref_count, table.name)

def _delete_ref_count(self, name: str) -> None:
def _delete_ref_count(self, name: str) -> None: # pragma: no cover
"""
Decrements the reference count of a temporary table,
and if the count reaches zero, puts this table in the queue for cleanup.
"""
self.ref_count_map[name] -= 1
if self.ref_count_map[name] == 0:
self.ref_count_map.pop(name)
# clean up
self.queue.put(name)
if self.session.auto_clean_up_temp_table_enabled:
sfc-gh-yzou marked this conversation as resolved.
Show resolved Hide resolved
self.drop_table(name)
elif self.ref_count_map[name] < 0:
logging.debug(
f"Unexpected reference count {self.ref_count_map[name]} for table {name}"
)

def process_cleanup(self) -> None:
    """Worker loop: drain the cleanup queue, dropping one table at a time,
    until ``stop_event`` is set."""
    while not self.stop_event.is_set():
        try:
            # The 1-second timeout keeps the loop responsive to
            # stop_event instead of blocking forever on an empty queue.
            table_name = self.queue.get(timeout=1)
        except Empty:
            # Nothing queued within the timeout window; re-check the
            # stop flag and wait again.
            continue
        self.drop_table(table_name)

def drop_table(self, name: str) -> None:  # pragma: no cover
    """Drop the named temp table via an asynchronous (non-blocking) query.

    Best-effort: any failure is logged as a warning rather than raised,
    since this runs during garbage collection. If the failure happened
    before a query id was assigned, a dedicated telemetry record is sent
    because the query will not show up in server-side job views.
    """
    common_log_text = f"temp table {name} in session {self.session.session_id}"
    logging.debug(f"Ready to drop {common_log_text}")
    query_id = None
    try:
        async_job = self.session.sql(
            f"drop table if exists {name} /* internal query to drop unused temp table */",
        )._internal_collect_with_tag_no_telemetry(
            block=False, statement_params={DROP_TABLE_STATEMENT_PARAM_NAME: name}
        )
        query_id = async_job.query_id
        logging.debug(f"Dropping {common_log_text} with query id {query_id}")
    except Exception as ex:  # pragma: no cover
        warning_message = f"Failed to drop {common_log_text}, exception: {ex}"
        logging.warning(warning_message)
        if query_id is None:
            # If no query_id is available, the query hasn't been accepted
            # by the server and won't occur in the job view; send a
            # separate telemetry record for tracking.
            self.session._conn._telemetry_client.send_temp_table_cleanup_abnormal_exception_telemetry(
                self.session.session_id,
                name,
                str(ex),
            )

def is_alive(self) -> bool:
    """Return True if the cleanup thread exists and is still running."""
    thread = self.cleanup_thread
    return thread is not None and thread.is_alive()

def start(self) -> None:
    """Start the background cleanup thread if it is not already running.

    Always clears the stop flag first so a previously stopped cleaner can
    be restarted.
    """
    self.stop_event.clear()
    if self.is_alive():
        # A worker is already processing the queue; nothing to do.
        return
    self.cleanup_thread = Thread(target=self.process_cleanup)
    self.cleanup_thread.start()

def stop(self) -> None:
    """
    Stops the cleaner (no-op for execution) and sends the cleanup telemetry.

    Called both on session close and when the parameter is turned off; the
    payload records the current parameter value so the two cases can be
    distinguished on the telemetry side.
    """
    self.session._conn._telemetry_client.send_temp_table_cleanup_telemetry(
        self.session.session_id,
        temp_table_cleaner_enabled=self.session.auto_clean_up_temp_table_enabled,
        num_temp_tables_cleaned=self.num_temp_tables_cleaned,
        num_temp_tables_created=self.num_temp_tables_created,
    )

@property
def num_temp_tables_created(self) -> int:
    """Total number of temp tables ever registered with this cleaner.

    Entries are never removed from ``ref_count_map`` (counts are zeroed,
    not popped), so its size equals the total number created.
    """
    return len(self.ref_count_map)

@property
def num_temp_tables_cleaned(self) -> int:
    """Number of tracked temp tables whose reference count reached zero."""
    # TODO SNOW-1662536: we may need a separate counter for the number of
    #  tables cleaned when the parameter is enabled
    return sum(v == 0 for v in self.ref_count_map.values())
39 changes: 27 additions & 12 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,6 @@ def __init__(
self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None
self._runtime_version_from_requirement: str = None
self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self)
if self._auto_clean_up_temp_table_enabled:
self._temp_table_auto_cleaner.start()

_logger.info("Snowpark Session information: %s", self._session_info)

def __enter__(self):
Expand Down Expand Up @@ -623,8 +620,8 @@ def close(self) -> None:
raise SnowparkClientExceptionMessages.SERVER_FAILED_CLOSE_SESSION(str(ex))
finally:
try:
self._conn.close()
self._temp_table_auto_cleaner.stop()
self._conn.close()
_logger.info("Closed session: %s", self._session_id)
finally:
_remove_session(self)
Expand Down Expand Up @@ -658,10 +655,33 @@ def auto_clean_up_temp_table_enabled(self) -> bool:
:meth:`DataFrame.cache_result` in the current session when the DataFrame is no longer referenced (i.e., gets garbage collected).
The default value is ``False``.
sfc-gh-helmeleegy marked this conversation as resolved.
Show resolved Hide resolved

Example::

>>> import gc
>>>
>>> def f(session: Session) -> str:
... df = session.create_dataframe(
... [[1, 2], [3, 4]], schema=["a", "b"]
... ).cache_result()
... return df.table_name
...
>>> session.auto_clean_up_temp_table_enabled = True
>>> table_name = f(session)
>>> assert table_name
>>> gc.collect() # doctest: +SKIP
>>>
>>> # The temporary table created by cache_result will be dropped when the DataFrame is no longer referenced
>>> # outside the function
>>> session.sql(f"show tables like '{table_name}'").count()
0

>>> session.auto_clean_up_temp_table_enabled = False

Note:
Copy link
Collaborator Author

@sfc-gh-jdu sfc-gh-jdu Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can still provide this guarantee, by checking all entries of the count map whether a count reaches 0 during garbage collection. But given we're not using watch thread, I don't think this guarantee makes much sense. We only need to do our best effort to clean up temp tables when this parameter is enabled. Alternatively, we can always easily add this guarantee later when the customer requests it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This behavior is irrelevant with the threading work, it is a behavior of the cleaner. if we do not have this, what we are saying is we only clean up temp tables whose reference reach 0 after the cleaner is started. let's make sure this is documented clearly somewhere

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can say the temporary tables will only be dropped when this parameter is turned on during garbage collection, whereas the garbage collection in Python is triggered opportunistically and the timing is not guaranteed. I think it's also clear. Let me know wdyt.

Even if this parameter is ``False``, Snowpark still records temporary tables when
their corresponding DataFrame are garbage collected. Therefore, if you turn it on in the middle of your session or after turning it off,
the target temporary tables will still be cleaned up accordingly.
Temporary tables will only be dropped if this parameter is enabled during garbage collection.
If a temporary table is no longer referenced when the parameter is on, it will be dropped during garbage collection.
However, if garbage collection occurs while the parameter is off, the table will not be removed.
Note that Python's garbage collection is triggered opportunistically, with no guaranteed timing.
"""
return self._auto_clean_up_temp_table_enabled

Expand Down Expand Up @@ -755,11 +775,6 @@ def auto_clean_up_temp_table_enabled(self, value: bool) -> None:
self._session_id, value
)
self._auto_clean_up_temp_table_enabled = value
is_alive = self._temp_table_auto_cleaner.is_alive()
if value and not is_alive:
self._temp_table_auto_cleaner.start()
elif not value and is_alive:
self._temp_table_auto_cleaner.stop()
else:
raise ValueError(
"value for auto_clean_up_temp_table_enabled must be True or False!"
Expand Down
48 changes: 48 additions & 0 deletions tests/integ/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,3 +1223,51 @@ def send_telemetry():
data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry)
assert data == expected_data
assert type_ == "snowpark_compilation_stage_statistics"


def test_temp_table_cleanup(session):
    """Verify payload and type of the temp-table cleanup summary telemetry."""
    telemetry_tracker = TelemetryDataTracker(session)
    client = session._conn._telemetry_client

    expected_data = {
        "session_id": session.session_id,
        "temp_table_cleaner_enabled": True,
        "num_temp_tables_cleaned": 2,
        "num_temp_tables_created": 5,
    }

    def send_telemetry():
        client.send_temp_table_cleanup_telemetry(
            session.session_id,
            temp_table_cleaner_enabled=True,
            num_temp_tables_cleaned=2,
            num_temp_tables_created=5,
        )

    data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry)
    assert type_ == "snowpark_temp_table_cleanup"
    assert data == expected_data


def test_temp_table_cleanup_exception(session):
    """Verify payload and type of the abnormal-exception cleanup telemetry."""
    telemetry_tracker = TelemetryDataTracker(session)
    client = session._conn._telemetry_client

    expected_data = {
        "session_id": session.session_id,
        "temp_table_cleanup_abnormal_exception_table_name": "table_name_placeholder",
        "temp_table_cleanup_abnormal_exception_message": "exception_message_placeholder",
    }

    def send_telemetry():
        client.send_temp_table_cleanup_abnormal_exception_telemetry(
            session.session_id,
            table_name="table_name_placeholder",
            exception_message="exception_message_placeholder",
        )

    data, type_, _ = telemetry_tracker.extract_telemetry_log_data(-1, send_telemetry)
    assert type_ == "snowpark_temp_table_cleanup_abnormal_exception"
    assert data == expected_data
Loading
Loading