diff --git a/ipv8/taskmanager.py b/ipv8/taskmanager.py index ec7bd2d19..01b8d6568 100644 --- a/ipv8/taskmanager.py +++ b/ipv8/taskmanager.py @@ -4,11 +4,23 @@ import time import traceback import types -from asyncio import CancelledError, Future, Task, ensure_future, gather, get_running_loop, iscoroutinefunction, sleep +from asyncio import ( + CancelledError, + Event, + Future, + Task, + ensure_future, + gather, + get_running_loop, + iscoroutinefunction, + shield, + sleep, +) +from collections.abc import Awaitable from contextlib import suppress from functools import wraps from threading import RLock -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, cast from weakref import WeakValueDictionary from .util import coroutine, succeed @@ -207,6 +219,47 @@ def register_executor_task(self, name: str, func: Callable, *args: Any, # noqa: return self.register_anonymous_task(name, future) return self.register_task(name, future) + async def register_shutdown_task(self, basename: str, task: Callable | Coroutine | Future, + *args: Any, **kwargs) -> Future: # noqa: ANN401 + """ + Register a task to be run when this manager is shut down. + """ + done_future: Future[bool] = Future() + + async def catch_shutdown() -> None: + """ + Wait until we are Cancelled and trigger if we are in shutdown mode. + We can't run any async work here: this will be scheduled AFTER shutdown, not during it. + """ + try: + await Event().wait() + finally: + if self._shutdown: + done_future.set_result(True) + else: + done_future.set_result(False) # Some odd crash + + async def after_cancel() -> None: + """ + This is a registered but uncancellable task. It will be awaited but cannot detect cancels. + """ + run_callback = await done_future + if run_callback: + fut = task + if callable(task): + fut = task(*args, **kwargs) + if not callable(task) or iscoroutinefunction(task): + await cast(Awaitable, fut) + + done_future.after_cancel_task = after_cancel() # type: ignore[attr-defined] + self.register_anonymous_task(f"[Catch shutdown] {basename}", catch_shutdown) + self.register_anonymous_task(f"[Run shutdown] {basename}", + shield(done_future.after_cancel_task)) # type: ignore[attr-defined] + + await sleep(0) # Enter both infinite loops + + return done_future + def cancel_pending_task(self, name: Hashable) -> Future: """ Cancels the named task. diff --git a/ipv8/test/test_taskmanager.py b/ipv8/test/test_taskmanager.py index 6163f1d44..98c4e982f 100644 --- a/ipv8/test/test_taskmanager.py +++ b/ipv8/test/test_taskmanager.py @@ -291,6 +291,19 @@ async def test_register_executor_task_anon(self) -> None: _ = self.tm.register_executor_task("test", test, anon=True) self.assertEqual(2, len(self.tm.get_tasks())) + async def test_register_shutdown_task(self) -> None: + """ + Check if registering a task on TaskManager shutdown works. + """ + sub_manager = TaskManager() + sub_fut = sub_manager.register_task("sub test", Future()) + fut = await self.tm.register_shutdown_task("test", sub_manager.shutdown_task_manager) + + await self.tm.shutdown_task_manager() + + self.assertTrue(fut.done()) + self.assertTrue(sub_fut.cancelled()) + async def test_get_task_existing_pending(self) -> None: """ Check if an existing pending task can be retrieved.