Skip to content

Commit

Permalink
Add taskmanager register shutdown task
Browse files Browse the repository at this point in the history
  • Loading branch information
qstokkink committed Nov 25, 2024
1 parent a2b42c8 commit 30cc13b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
30 changes: 28 additions & 2 deletions ipv8/taskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
import time
import traceback
import types
from asyncio import CancelledError, Future, Task, ensure_future, gather, get_running_loop, iscoroutinefunction, sleep
from asyncio import (
CancelledError,
Future,
Task,
ensure_future,
gather,
get_running_loop,
iscoroutinefunction,
sleep,
)
from contextlib import suppress
from functools import wraps
from threading import RLock
Expand Down Expand Up @@ -93,6 +102,7 @@ def __init__(self) -> None:
Create a new TaskManager and start the introspection loop.
"""
self._pending_tasks: WeakValueDictionary[Hashable, Future] = WeakValueDictionary()
self._shutdown_tasks: list[tuple[Callable | Coroutine, tuple[Any, ...], dict[str, Any]]] = []
self._task_lock = RLock()
self._shutdown = False
self._counter = 0
Expand Down Expand Up @@ -186,7 +196,8 @@ def done_cb(future: Future) -> None:
task.add_done_callback(done_cb)
return task

def register_anonymous_task(self, basename: str, task: Callable | Coroutine | Future, *args: Any, **kwargs) -> Future: # noqa: ANN401
def register_anonymous_task(self, basename: str, task: Callable | Coroutine | Future,
*args: Any, **kwargs) -> Future: # noqa: ANN401
"""
Wrapper for register_task to derive a unique name from the basename.
"""
Expand All @@ -207,6 +218,12 @@ def register_executor_task(self, name: str, func: Callable, *args: Any, # noqa:
return self.register_anonymous_task(name, future)
return self.register_task(name, future)

def register_shutdown_task(self, task: Callable | Coroutine, *args: Any, **kwargs) -> None: # noqa: ANN401
"""
Register a task to be run when this manager is shut down.
"""
self._shutdown_tasks.append((task, args, kwargs))

def cancel_pending_task(self, name: Hashable) -> Future:
"""
Cancels the named task.
Expand Down Expand Up @@ -273,6 +290,9 @@ async def shutdown_task_manager(self) -> None:
"""
Clear the task manager, cancel all pending tasks and disallow new tasks being added.
"""
if self._shutdown:
return

with self._task_lock:
self._shutdown = True
tasks = self.cancel_all_pending_tasks()
Expand All @@ -281,5 +301,11 @@ async def shutdown_task_manager(self) -> None:
with suppress(CancelledError):
await gather(*tasks)

for post_shutdown_task, args, kwargs in self._shutdown_tasks:
if iscoroutinefunction(post_shutdown_task):
await post_shutdown_task(*args, **kwargs)
elif callable(post_shutdown_task): # This is not necessary, but Mypy wants it here
post_shutdown_task(*args, **kwargs)


__all__ = ["TaskManager", "task"]
23 changes: 23 additions & 0 deletions ipv8/test/test_taskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,29 @@ async def test_register_executor_task_anon(self) -> None:
_ = self.tm.register_executor_task("test", test, anon=True)
self.assertEqual(2, len(self.tm.get_tasks()))

async def test_register_shutdown_task(self) -> None:
"""
Check if registering a task on TaskManager shutdown works.
"""
sub_manager = TaskManager()
sub_fut = sub_manager.register_task("sub test", Future())
self.tm.register_shutdown_task(sub_manager.shutdown_task_manager)

await self.tm.shutdown_task_manager()

self.assertTrue(sub_fut.cancelled())

async def test_register_shutdown_function(self) -> None:
"""
Check if registering a plain function on TaskManager shutdown works.
"""
checker = ["test"]
self.tm.register_shutdown_task(checker.pop)

await self.tm.shutdown_task_manager()

self.assertEqual(0, len(checker))

async def test_get_task_existing_pending(self) -> None:
"""
Check if an existing pending task can be retrieved.
Expand Down

0 comments on commit 30cc13b

Please sign in to comment.