From 728c47fb3fd06c886aec2eb8bcacb468d36d3359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20S=C3=A1nchez-Gallego?= Date: Fri, 29 Nov 2024 13:05:38 -0800 Subject: [PATCH] Add with_timeout() to utils and add tests for utils.py --- CHANGELOG.md | 11 +++++++ pyproject.toml | 1 - src/lvmopstools/utils.py | 53 +++++++++++++++++++++++++++--- tests/test_utils.py | 71 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 131 insertions(+), 5 deletions(-) create mode 100644 tests/test_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 219c0db..fdb4efb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## Next version + +### 🚀 New + +* Add `with_timeout()` to utils. + +### ✨ Improved + +* Add test coverage for `utils.py`. + + ## 0.4.2 - November 27, 2024 ### ✨ Improved diff --git a/pyproject.toml b/pyproject.toml index 0d568db..0313314 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,6 @@ omit = [ "src/lvmopstools/ds9.py", "src/lvmopstools/kubernetes.py", "src/lvmopstools/influxdb.py", - "src/lvmopstools/utils.py", "src/lvmopstools/devices/specs.py", "src/lvmopstools/devices/nps.py" ] diff --git a/src/lvmopstools/utils.py b/src/lvmopstools/utils.py index e400aa2..6cb6d5f 100644 --- a/src/lvmopstools/utils.py +++ b/src/lvmopstools/utils.py @@ -10,15 +10,21 @@ import asyncio -from typing import Any +from typing import Any, Coroutine, TypeVar from clu import AMQPClient -__all__ = ["get_amqp_client", "get_exception_data", "stop_event_loop"] +__all__ = [ + "get_amqp_client", + "get_exception_data", + "stop_event_loop", + "with_timeout", + "is_notebook", +] -async def get_amqp_client(**kwargs) -> AMQPClient: +async def get_amqp_client(**kwargs) -> AMQPClient: # pragma: no cover """Returns a CLU AMQP client.""" amqp_client = AMQPClient(**kwargs) @@ -62,7 +68,7 @@ def get_exception_data(exception: Exception | None, traceback_frame: int = 0): return exception_data -async def stop_event_loop(timeout: float | None = 5): +async def stop_event_loop(timeout: float | None = 5): # pragma: no cover """Cancels all running tasks and stops the event loop.""" for task in asyncio.all_tasks(): @@ -93,3 +99,42 @@ def is_notebook() -> bool: return False # Other type (?) except NameError: return False # Probably standard Python interpreter + + +T = TypeVar("T", bound=Any) + + +async def with_timeout( + coro: Coroutine[Any, Any, T], + timeout: float | None, + raise_on_timeout: bool = True, +) -> T | None: + """Runs a coroutine with a timeout. + + Parameters + ---------- + coro + The coroutine to run. + timeout + The timeout in seconds. + raise_on_timeout + If :obj:`True`, raises a :class:`asyncio.TimeoutError` if the coroutine times + out, otherwise returns :obj:`None`. + + Returns + ------- + result + The result of the coroutine. + + Raises + ------ + asyncio.TimeoutError + If the coroutine times out. + + """ + + try: + return await asyncio.wait_for(coro, timeout) + except asyncio.TimeoutError: + if raise_on_timeout: + raise asyncio.TimeoutError(f"Timed out after {timeout} seconds.") diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..8cea86e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# @Author: José Sánchez-Gallego (gallegoj@uw.edu) +# @Date: 2024-11-29 +# @Filename: test_utils.py +# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause) + +from __future__ import annotations + +import asyncio + +import pytest +import pytest_mock + +import lvmopstools.utils +from lvmopstools.utils import is_notebook, with_timeout + + +async def _timeout(delay: float): + await asyncio.sleep(delay) + return True + + +async def test_with_timeout(): + with pytest.raises(asyncio.TimeoutError): + await with_timeout(_timeout(0.5), timeout=0.1) + + +async def test_with_timeout_no_raise(): + result = await with_timeout(_timeout(0.5), timeout=0.1, raise_on_timeout=False) + assert result is None + + +class GetPythonMocker: + def __init__(self, shell: str): + self.shell = shell + self.__class__.__name__ = shell + + def __call__(self): + return self + + +@pytest.mark.parametrize( + "shell, result", + [ + ("ZMQInteractiveShell", True), + ("TerminalInteractiveShell", False), + ("other", False), + ], +) +async def test_is_notebook(shell: str, result: bool, mocker: pytest_mock.MockerFixture): + mocker.patch.object( + lvmopstools.utils, + "get_ipython", + return_value=GetPythonMocker(shell), + create=True, + ) + + assert is_notebook() == result + + +async def test_is_notebook_name_Error(mocker: pytest_mock.MockerFixture): + mocker.patch.object( + lvmopstools.utils, + "get_ipython", + side_effect=NameError, + create=True, + ) + + assert not is_notebook()