Skip to content

Commit

Permalink
Add GatheringTaskGroup class to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Jul 12, 2024
1 parent a58c991 commit df6436a
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## [main](https://github.com/sdss/sdsstools/compare/1.7.1...main)

- Added `GatheringTaskGroup` class that extends the functionality of `asyncio.TaskGroup`.

## [1.7.1](https://github.com/sdss/sdsstools/compare/1.7.0...1.7.1) - 2024-07-02

- Support Numpy 2 (this only affects the `yanny` module).
Expand Down
36 changes: 35 additions & 1 deletion src/sdsstools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
from functools import partial


__all__ = ["Timer", "get_temporary_file_path", "run_in_executor", "cancel_task"]
__all__ = [
"Timer",
"get_temporary_file_path",
"run_in_executor",
"cancel_task",
"GatheringTaskGroup",
]


class Timer:
Expand Down Expand Up @@ -102,3 +108,31 @@ async def cancel_task(task: asyncio.Future | None):
task.cancel()
with suppress(asyncio.CancelledError):
await task


class GatheringTaskGroup(asyncio.TaskGroup):
"""An extension to ``asyncio.TaskGroup`` that keeps track of the tasks created.
Adapted from https://stackoverflow.com/questions/75204560/consuming-taskgroup-response
"""

def __init__(self):
super().__init__()
self.__tasks = []

def create_task(self, coro, *, name=None, context=None):
"""Creates a task and appends it to the list of tasks."""

task = super().create_task(coro, name=name, context=context)
self.__tasks.append(task)

return task

def results(self):
"""Returns the results of the tasks in the same order they were created."""

if len(self._tasks) > 0:
raise RuntimeError("Not all tasks have completed yet.")

return [task.result() for task in self.__tasks]
33 changes: 32 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@

import pytest

from sdsstools.utils import Timer, cancel_task, get_temporary_file_path, run_in_executor
from sdsstools.utils import (
GatheringTaskGroup,
Timer,
cancel_task,
get_temporary_file_path,
run_in_executor,
)


def test_timer():
Expand Down Expand Up @@ -89,3 +95,28 @@ async def test_cancel_task_done():
async def test_cancel_task_None():
task = None
await cancel_task(task)


async def test_gathering_task_group():
async def _task(i):
await asyncio.sleep(0.1)
return i

async with GatheringTaskGroup() as group:
for i in range(10):
group.create_task(_task(i))

assert group.results() == list(range(10))


async def test_gathering_task_group_results_fails():
async def _task(i):
await asyncio.sleep(0.1)
return i

async with GatheringTaskGroup() as group:
for i in range(10):
group.create_task(_task(i))

with pytest.raises(RuntimeError):
group.results()

0 comments on commit df6436a

Please sign in to comment.