diff --git a/ipv8/taskmanager.py b/ipv8/taskmanager.py index ec7bd2d19..168bbac9a 100644 --- a/ipv8/taskmanager.py +++ b/ipv8/taskmanager.py @@ -4,7 +4,16 @@ import time import traceback import types -from asyncio import CancelledError, Future, Task, ensure_future, gather, get_running_loop, iscoroutinefunction, sleep +from asyncio import ( + CancelledError, + Future, + Task, + ensure_future, + gather, + get_running_loop, + iscoroutinefunction, + sleep, +) from contextlib import suppress from functools import wraps from threading import RLock @@ -93,6 +102,7 @@ def __init__(self) -> None: Create a new TaskManager and start the introspection loop. """ self._pending_tasks: WeakValueDictionary[Hashable, Future] = WeakValueDictionary() + self._shutdown_tasks: list[tuple[Callable | Coroutine, tuple[Any, ...], dict[str, Any]]] = [] self._task_lock = RLock() self._shutdown = False self._counter = 0 @@ -186,7 +196,8 @@ def done_cb(future: Future) -> None: task.add_done_callback(done_cb) return task - def register_anonymous_task(self, basename: str, task: Callable | Coroutine | Future, *args: Any, **kwargs) -> Future: # noqa: ANN401 + def register_anonymous_task(self, basename: str, task: Callable | Coroutine | Future, + *args: Any, **kwargs) -> Future: # noqa: ANN401 """ Wrapper for register_task to derive a unique name from the basename. """ @@ -207,6 +218,12 @@ def register_executor_task(self, name: str, func: Callable, *args: Any, # noqa: return self.register_anonymous_task(name, future) return self.register_task(name, future) + def register_shutdown_task(self, task: Callable | Coroutine, *args: Any, **kwargs) -> None: # noqa: ANN401 + """ + Register a task to be run when this manager is shut down. + """ + self._shutdown_tasks.append((task, args, kwargs)) + def cancel_pending_task(self, name: Hashable) -> Future: """ Cancels the named task. @@ -273,6 +290,9 @@ async def shutdown_task_manager(self) -> None: """ Clear the task manager, cancel all pending tasks and disallow new tasks being added. """ + if self._shutdown: + return + with self._task_lock: self._shutdown = True tasks = self.cancel_all_pending_tasks() @@ -281,5 +301,11 @@ async def shutdown_task_manager(self) -> None: with suppress(CancelledError): await gather(*tasks) + for post_shutdown_task, args, kwargs in self._shutdown_tasks: + if iscoroutinefunction(post_shutdown_task): + await post_shutdown_task(*args, **kwargs) + elif callable(post_shutdown_task): # This is not necessary, but Mypy wants it here + post_shutdown_task(*args, **kwargs) + __all__ = ["TaskManager", "task"] diff --git a/ipv8/test/test_taskmanager.py b/ipv8/test/test_taskmanager.py index 6163f1d44..899c4fcd9 100644 --- a/ipv8/test/test_taskmanager.py +++ b/ipv8/test/test_taskmanager.py @@ -291,6 +291,29 @@ async def test_register_executor_task_anon(self) -> None: _ = self.tm.register_executor_task("test", test, anon=True) self.assertEqual(2, len(self.tm.get_tasks())) + async def test_register_shutdown_task(self) -> None: + """ + Check if registering a task on TaskManager shutdown works. + """ + sub_manager = TaskManager() + sub_fut = sub_manager.register_task("sub test", Future()) + self.tm.register_shutdown_task(sub_manager.shutdown_task_manager) + + await self.tm.shutdown_task_manager() + + self.assertTrue(sub_fut.cancelled()) + + async def test_register_shutdown_function(self) -> None: + """ + Check if registering a plain function on TaskManager shutdown works. + """ + checker = ["test"] + self.tm.register_shutdown_task(checker.pop) + + await self.tm.shutdown_task_manager() + + self.assertEqual(0, len(checker)) + async def test_get_task_existing_pending(self) -> None: """ Check if an existing pending task can be retrieved.