From 5bf7d21586072ce3b2e32db3252a28584dfc7540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20S=C3=A1nchez-Gallego?= Date: Thu, 12 Sep 2024 09:50:04 -0700 Subject: [PATCH] Add GatheringTaskGroup class for Python 3.10 and below --- src/sdsstools/utils.py | 46 ++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 8 +------- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/src/sdsstools/utils.py b/src/sdsstools/utils.py index 56d1ec6..55bd75d 100644 --- a/src/sdsstools/utils.py +++ b/src/sdsstools/utils.py @@ -23,6 +23,7 @@ "get_temporary_file_path", "run_in_executor", "cancel_task", + "GatheringTaskGroup", ] @@ -139,4 +140,49 @@ def results(self): return [task.result() for task in self.__tasks] +else: + + class GatheringTaskGroup: + """Simple implementation of ``asyncio.TaskGroup`` for Python 3.10 and below. + + The behaviour of this class is not exactly the same as ``asyncio.TaskGroup``, + especially when it comes to handling of exceptions during execution. + + """ + + def __init__(self): + self._tasks = [] + self._joined: bool = False + + def __repr__(self): + return f"" + + async def __aenter__(self): + self._joined = False + self._tasks = [] + + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if exc_type is not None: + raise RuntimeError("An error occurred in the task group.") + + await asyncio.gather(*self._tasks) + + def create_task(self, coro): + """Creates a task and appends it to the list of tasks.""" + + task = self.create_task(coro) + self._tasks.append(task) + + return task + + async def results(self): + """Returns the results of the tasks in the same order they were created.""" + + if self._joined: + raise RuntimeError("Tasks have not been gathered yet.") + + return [task.result() for task in self._tasks] + __all__.append("GatheringTaskGroup") diff --git a/test/test_utils.py b/test/test_utils.py index d634109..0c61212 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,13 +9,13 @@ from __future__ import annotations import asyncio -import sys import warnings from time import sleep import pytest from sdsstools.utils import ( + GatheringTaskGroup, Timer, cancel_task, get_temporary_file_path, @@ -23,10 +23,6 @@ ) -if sys.version_info >= (3, 11): - from sdsstools.utils import GatheringTaskGroup - - def test_timer(): with Timer() as timer: sleep(0.1) @@ -101,7 +97,6 @@ async def test_cancel_task_None(): await cancel_task(task) -@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") async def test_gathering_task_group(): async def _task(i): await asyncio.sleep(0.1) @@ -114,7 +109,6 @@ async def _task(i): assert group.results() == list(range(10)) -@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") async def test_gathering_task_group_results_fails(): async def _task(i): await asyncio.sleep(0.1)