From a31b713dcd91d79baa969fc29af6b6970a5a4d1f Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2024 01:37:14 -0700 Subject: [PATCH 1/5] try more complete lifecycle Signed-off-by: Yee Hing Tong --- flytekit/bin/entrypoint.py | 21 +++----------------- flytekit/utils/__init__.py | 0 flytekit/utils/loop_handling.py | 35 +++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 18 deletions(-) create mode 100644 flytekit/utils/__init__.py create mode 100644 flytekit/utils/loop_handling.py diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 8febb38ad1..692f8e594b 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -46,6 +46,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.tools.fast_registration import download_distribution as _download_distribution from flytekit.tools.module_loader import load_object_from_module +from flytekit.utils.loop_handling import use_event_loop def get_version_message(): @@ -69,23 +70,6 @@ def _compute_array_job_index(): return offset -def _get_working_loop(): - """Returns a running event loop.""" - try: - return asyncio.get_running_loop() - except RuntimeError: - with warnings.catch_warnings(): - warnings.simplefilter("error", DeprecationWarning) - try: - return asyncio.get_event_loop_policy().get_event_loop() - # Since version 3.12, DeprecationWarning is emitted if there is no - # current event loop. - except DeprecationWarning: - loop = asyncio.get_event_loop_policy().new_event_loop() - asyncio.set_event_loop(loop) - return loop - - def _dispatch_execute( ctx: FlyteContext, load_task: Callable[[], PythonTask], @@ -125,7 +109,8 @@ def _dispatch_execute( if inspect.iscoroutine(outputs): # Handle eager-mode (async) tasks logger.info("Output is a coroutine") - outputs = _get_working_loop().run_until_complete(outputs) + with use_event_loop() as loop: + outputs = loop.run_until_complete(outputs) # Step3a if isinstance(outputs, VoidPromise): diff --git a/flytekit/utils/__init__.py b/flytekit/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/utils/loop_handling.py b/flytekit/utils/loop_handling.py new file mode 100644 index 0000000000..6aaf8a7c22 --- /dev/null +++ b/flytekit/utils/loop_handling.py @@ -0,0 +1,35 @@ +import asyncio +from signal import SIGINT, SIGTERM +from contextlib import contextmanager + +from concurrent.futures import ThreadPoolExecutor + +from flytekit.loggers import logger + + +def handler(loop, s: int): + loop.stop() + logger.debug(f"Shutting down loop at {id(loop)} via {s!s}") + loop.remove_signal_handler(SIGTERM) + loop.add_signal_handler(SIGINT, lambda: None) + + +@contextmanager +def use_event_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + executor = ThreadPoolExecutor() + loop.set_default_executor(executor) + for sig in (SIGTERM, SIGINT): + loop.add_signal_handler(sig, handler, loop, sig) + try: + yield loop + finally: + tasks = asyncio.all_tasks(loop=loop) + for t in tasks: + logger.debug(f"canceling {t.get_name()}") + t.cancel() + group = asyncio.gather(*tasks, return_exceptions=True) + loop.run_until_complete(group) + executor.shutdown(wait=True) + loop.close() From 0d2a7ad4370e9af45fcfaf8f857a5b0676bf5d2a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2024 04:06:40 -0700 Subject: [PATCH 2/5] lint Signed-off-by: Yee Hing Tong --- flytekit/bin/entrypoint.py | 2 -- flytekit/utils/loop_handling.py | 5 ++--- tests/flytekit/unit/experimental/test_eager_workflows.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 692f8e594b..6916ed880e 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import datetime import inspect @@ -9,7 +8,6 @@ import sys import tempfile import traceback -import warnings from sys import exit from typing import Callable, List, Optional diff --git a/flytekit/utils/loop_handling.py b/flytekit/utils/loop_handling.py index 6aaf8a7c22..544abb6271 100644 --- a/flytekit/utils/loop_handling.py +++ b/flytekit/utils/loop_handling.py @@ -1,8 +1,7 @@ import asyncio -from signal import SIGINT, SIGTERM -from contextlib import contextmanager - from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from signal import SIGINT, SIGTERM from flytekit.loggers import logger diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py index 898d11a5ba..a7d516802b 100644 --- a/tests/flytekit/unit/experimental/test_eager_workflows.py +++ b/tests/flytekit/unit/experimental/test_eager_workflows.py @@ -11,7 +11,7 @@ from flytekit import dynamic, task, workflow -from flytekit.bin.entrypoint import _get_working_loop, _dispatch_execute +from flytekit.bin.entrypoint import _dispatch_execute from flytekit.core import context_manager from flytekit.core.promise import VoidPromise from flytekit.exceptions.user import FlyteValidationException From c16b4d7eaa6f597bb9b428886b530b0985b59d59 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2024 04:10:19 -0700 Subject: [PATCH 3/5] remove test Signed-off-by: Yee Hing Tong --- .../unit/experimental/test_eager_workflows.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py index a7d516802b..732a15c198 100644 --- a/tests/flytekit/unit/experimental/test_eager_workflows.py +++ b/tests/flytekit/unit/experimental/test_eager_workflows.py @@ -283,24 +283,3 @@ async def eager_wf_flyte_directory() -> str: assert result == "some data" -@mock.patch("flytekit.core.utils.load_proto_from_file") -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") -@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") -@mock.patch("flytekit.core.utils.write_proto_to_file") -def test_eager_workflow_dispatch(mock_write_to_file, mock_put_data, mock_get_data, mock_load_proto, event_loop): - """Test that event loop is preserved after executing eager workflow via dispatch.""" - - @eager - async def eager_wf(): - await asyncio.sleep(0.1) - return - - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) - ) - ) as ctx: - _dispatch_execute(ctx, lambda: eager_wf, "inputs path", "outputs prefix") - loop_after_execute = asyncio.get_event_loop_policy().get_event_loop() - assert event_loop == loop_after_execute From 1f1372267f90c9c9271877f53ebf0703acae8e80 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2024 05:01:42 -0700 Subject: [PATCH 4/5] lint Signed-off-by: Yee Hing Tong --- tests/flytekit/unit/experimental/test_eager_workflows.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py index 732a15c198..6903b8184a 100644 --- a/tests/flytekit/unit/experimental/test_eager_workflows.py +++ b/tests/flytekit/unit/experimental/test_eager_workflows.py @@ -281,5 +281,3 @@ async def eager_wf_flyte_directory() -> str: result = asyncio.run(eager_wf_flyte_directory()) assert result == "some data" - - From 7f7d374d0508dec769ea2e6e9a7177e1944b883b Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 11 Sep 2024 18:01:05 -0700 Subject: [PATCH 5/5] wrap the loop around the whole dispatch execute Signed-off-by: Yee Hing Tong --- flytekit/bin/entrypoint.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 6916ed880e..a6e8df4cf1 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import datetime import inspect @@ -107,8 +108,8 @@ def _dispatch_execute( if inspect.iscoroutine(outputs): # Handle eager-mode (async) tasks logger.info("Output is a coroutine") - with use_event_loop() as loop: - outputs = loop.run_until_complete(outputs) + loop = asyncio.get_running_loop() + outputs = loop.run_until_complete(outputs) # Step3a if isinstance(outputs, VoidPromise): @@ -401,7 +402,8 @@ def load_task(): f"Test detected, returning. Args were {inputs} {output_prefix} {raw_output_data_prefix} {resolver} {resolver_args}" ) return - _dispatch_execute(ctx, load_task, inputs, output_prefix) + with use_event_loop(): + _dispatch_execute(ctx, load_task, inputs, output_prefix) def _execute_map_task( @@ -462,7 +464,8 @@ def load_task(): ) return - _dispatch_execute(ctx, load_task, inputs, output_prefix) + with use_event_loop(): + _dispatch_execute(ctx, load_task, inputs, output_prefix) def normalize_inputs(