diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 8febb38ad1..a6e8df4cf1 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -9,7 +9,6 @@ import sys import tempfile import traceback -import warnings from sys import exit from typing import Callable, List, Optional @@ -46,6 +45,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.tools.fast_registration import download_distribution as _download_distribution from flytekit.tools.module_loader import load_object_from_module +from flytekit.utils.loop_handling import use_event_loop def get_version_message(): @@ -69,23 +69,6 @@ def _compute_array_job_index(): return offset -def _get_working_loop(): - """Returns a running event loop.""" - try: - return asyncio.get_running_loop() - except RuntimeError: - with warnings.catch_warnings(): - warnings.simplefilter("error", DeprecationWarning) - try: - return asyncio.get_event_loop_policy().get_event_loop() - # Since version 3.12, DeprecationWarning is emitted if there is no - # current event loop. - except DeprecationWarning: - loop = asyncio.get_event_loop_policy().new_event_loop() - asyncio.set_event_loop(loop) - return loop - - def _dispatch_execute( ctx: FlyteContext, load_task: Callable[[], PythonTask], @@ -125,7 +108,8 @@ def _dispatch_execute( if inspect.iscoroutine(outputs): # Handle eager-mode (async) tasks logger.info("Output is a coroutine") - outputs = _get_working_loop().run_until_complete(outputs) + loop = asyncio.get_running_loop() + outputs = loop.run_until_complete(outputs) # Step3a if isinstance(outputs, VoidPromise): @@ -418,7 +402,8 @@ def load_task(): f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}" ) return - _dispatch_execute(ctx, load_task, inputs, output_prefix) + with use_event_loop(): + _dispatch_execute(ctx, load_task, inputs, output_prefix) def _execute_map_task( @@ -479,7 +464,8 @@ def load_task(): ) return - _dispatch_execute(ctx, load_task, inputs, output_prefix) + with use_event_loop(): + _dispatch_execute(ctx, load_task, inputs, output_prefix) def normalize_inputs( diff --git a/flytekit/utils/__init__.py b/flytekit/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/utils/loop_handling.py b/flytekit/utils/loop_handling.py new file mode 100644 index 0000000000..544abb6271 --- /dev/null +++ b/flytekit/utils/loop_handling.py @@ -0,0 +1,34 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from signal import SIGINT, SIGTERM + +from flytekit.loggers import logger + + +def handler(loop, s: int): + loop.stop() + logger.debug(f"Shutting down loop at {id(loop)} via {s!s}") + loop.remove_signal_handler(SIGTERM) + loop.add_signal_handler(SIGINT, lambda: None) + + +@contextmanager +def use_event_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + executor = ThreadPoolExecutor() + loop.set_default_executor(executor) + for sig in (SIGTERM, SIGINT): + loop.add_signal_handler(sig, handler, loop, sig) + try: + yield loop + finally: + tasks = asyncio.all_tasks(loop=loop) + for t in tasks: + logger.debug(f"canceling {t.get_name()}") + t.cancel() + group = asyncio.gather(*tasks, return_exceptions=True) + loop.run_until_complete(group) + executor.shutdown(wait=True) + loop.close() diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py index 898d11a5ba..6903b8184a 100644 --- a/tests/flytekit/unit/experimental/test_eager_workflows.py +++ b/tests/flytekit/unit/experimental/test_eager_workflows.py @@ -11,7 +11,7 @@ from flytekit import dynamic, task, workflow -from flytekit.bin.entrypoint import _get_working_loop, _dispatch_execute +from flytekit.bin.entrypoint import _dispatch_execute from flytekit.core import context_manager from flytekit.core.promise import VoidPromise from flytekit.exceptions.user import FlyteValidationException @@ -281,26 +281,3 @@ async def eager_wf_flyte_directory() -> str: result = asyncio.run(eager_wf_flyte_directory()) assert result == "some data" - - -@mock.patch("flytekit.core.utils.load_proto_from_file") -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.core.utils.write_proto_to_file") -def test_eager_workflow_dispatch(mock_write_to_file, mock_put_data, mock_get_data, mock_load_proto, event_loop): - """Test that event loop is preserved after executing eager workflow via dispatch.""" - - @eager - async def eager_wf(): - await asyncio.sleep(0.1) - return - - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) - ) - ) as ctx: - _dispatch_execute(ctx, lambda: eager_wf, "inputs path", "outputs prefix") - loop_after_execute = asyncio.get_event_loop_policy().get_event_loop() - assert event_loop == loop_after_execute