From 21137e79ea6d59907b3b8d236c275c5190a612fe Mon Sep 17 00:00:00 2001 From: hasier Date: Wed, 12 Jun 2024 13:20:22 +0100 Subject: [PATCH 1/3] Add async strategies (#451) * Add async strategies * Fix init typing * Reuse is_coroutine_callable * Keep only async predicate overrides and DRY implementations * Ensure async and/or versions called when necessary * Run ruff format * Copy over strategies as async * Add release note --- .../add-async-actions-b249c527d99723bb.yaml | 5 + tenacity/__init__.py | 28 ++- tenacity/_utils.py | 12 ++ tenacity/{_asyncio.py => asyncio/__init__.py} | 86 ++++++--- tenacity/asyncio/retry.py | 125 +++++++++++++ tenacity/retry.py | 10 +- tests/test_asyncio.py | 165 +++++++++++++++++- 7 files changed, 396 insertions(+), 35 deletions(-) create mode 100644 releasenotes/notes/add-async-actions-b249c527d99723bb.yaml rename tenacity/{_asyncio.py => asyncio/__init__.py} (64%) create mode 100644 tenacity/asyncio/retry.py diff --git a/releasenotes/notes/add-async-actions-b249c527d99723bb.yaml b/releasenotes/notes/add-async-actions-b249c527d99723bb.yaml new file mode 100644 index 0000000..096a24f --- /dev/null +++ b/releasenotes/notes/add-async-actions-b249c527d99723bb.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Added the ability to use async functions for retries. This way, you can now use + asyncio coroutines for retry strategy predicates. diff --git a/tenacity/__init__.py b/tenacity/__init__.py index bcee3f5..7de36d4 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -24,7 +24,8 @@ import warnings from abc import ABC, abstractmethod from concurrent import futures -from inspect import iscoroutinefunction + +from . import _utils # Import all built-in retry strategies for easier usage. from .retry import retry_base # noqa @@ -87,6 +88,7 @@ if t.TYPE_CHECKING: import types + from . import asyncio as tasyncio from .retry import RetryBaseT from .stop import StopBaseT from .wait import WaitBaseT @@ -593,16 +595,24 @@ def retry(func: WrappedFn) -> WrappedFn: ... @t.overload def retry( - sleep: t.Callable[[t.Union[int, float]], t.Optional[t.Awaitable[None]]] = sleep, + sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep, stop: "StopBaseT" = stop_never, wait: "WaitBaseT" = wait_none(), - retry: "RetryBaseT" = retry_if_exception_type(), - before: t.Callable[["RetryCallState"], None] = before_nothing, - after: t.Callable[["RetryCallState"], None] = after_nothing, - before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None, + retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(), + before: t.Callable[ + ["RetryCallState"], t.Union[None, t.Awaitable[None]] + ] = before_nothing, + after: t.Callable[ + ["RetryCallState"], t.Union[None, t.Awaitable[None]] + ] = after_nothing, + before_sleep: t.Optional[ + t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] + ] = None, reraise: bool = False, retry_error_cls: t.Type["RetryError"] = RetryError, - retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None, + retry_error_callback: t.Optional[ + t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]] + ] = None, ) -> t.Callable[[WrappedFn], WrappedFn]: ... @@ -624,7 +634,7 @@ def wrap(f: WrappedFn) -> WrappedFn: f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)" ) r: "BaseRetrying" - if iscoroutinefunction(f): + if _utils.is_coroutine_callable(f): r = AsyncRetrying(*dargs, **dkw) elif ( tornado @@ -640,7 +650,7 @@ def wrap(f: WrappedFn) -> WrappedFn: return wrap -from tenacity._asyncio import AsyncRetrying # noqa:E402,I100 +from tenacity.asyncio import AsyncRetrying # noqa:E402,I100 if tornado: from tenacity.tornadoweb import TornadoRetrying diff --git a/tenacity/_utils.py b/tenacity/_utils.py index 4e34115..f11a088 100644 --- a/tenacity/_utils.py +++ b/tenacity/_utils.py @@ -87,3 +87,15 @@ def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool: partial_call = isinstance(call, functools.partial) and call.func dunder_call = partial_call or getattr(call, "__call__", None) return inspect.iscoroutinefunction(dunder_call) + + +def wrap_to_async_func( + call: typing.Callable[..., typing.Any], +) -> typing.Callable[..., typing.Awaitable[typing.Any]]: + if is_coroutine_callable(call): + return call + + async def inner(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + return call(*args, **kwargs) + + return inner diff --git a/tenacity/_asyncio.py b/tenacity/asyncio/__init__.py similarity index 64% rename from tenacity/_asyncio.py rename to tenacity/asyncio/__init__.py index b06303f..3ec0088 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/asyncio/__init__.py @@ -19,13 +19,29 @@ import sys import typing as t +import tenacity from tenacity import AttemptManager from tenacity import BaseRetrying from tenacity import DoAttempt from tenacity import DoSleep from tenacity import RetryCallState +from tenacity import RetryError +from tenacity import after_nothing +from tenacity import before_nothing from tenacity import _utils +# Import all built-in retry strategies for easier usage. +from .retry import RetryBaseT +from .retry import retry_all # noqa +from .retry import retry_any # noqa +from .retry import retry_if_exception # noqa +from .retry import retry_if_result # noqa +from ..retry import RetryBaseT as SyncRetryBaseT + +if t.TYPE_CHECKING: + from tenacity.stop import StopBaseT + from tenacity.wait import WaitBaseT + WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) @@ -38,15 +54,41 @@ def asyncio_sleep(duration: float) -> t.Awaitable[None]: class AsyncRetrying(BaseRetrying): - sleep: t.Callable[[float], t.Awaitable[t.Any]] - def __init__( self, - sleep: t.Callable[[float], t.Awaitable[t.Any]] = asyncio_sleep, - **kwargs: t.Any, + sleep: t.Callable[ + [t.Union[int, float]], t.Union[None, t.Awaitable[None]] + ] = asyncio_sleep, + stop: "StopBaseT" = tenacity.stop.stop_never, + wait: "WaitBaseT" = tenacity.wait.wait_none(), + retry: "t.Union[SyncRetryBaseT, RetryBaseT]" = tenacity.retry_if_exception_type(), + before: t.Callable[ + ["RetryCallState"], t.Union[None, t.Awaitable[None]] + ] = before_nothing, + after: t.Callable[ + ["RetryCallState"], t.Union[None, t.Awaitable[None]] + ] = after_nothing, + before_sleep: t.Optional[ + t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] + ] = None, + reraise: bool = False, + retry_error_cls: t.Type["RetryError"] = RetryError, + retry_error_callback: t.Optional[ + t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]] + ] = None, ) -> None: - super().__init__(**kwargs) - self.sleep = sleep + super().__init__( + sleep=sleep, # type: ignore[arg-type] + stop=stop, + wait=wait, + retry=retry, # type: ignore[arg-type] + before=before, # type: ignore[arg-type] + after=after, # type: ignore[arg-type] + before_sleep=before_sleep, # type: ignore[arg-type] + reraise=reraise, + retry_error_cls=retry_error_cls, + retry_error_callback=retry_error_callback, + ) async def __call__( # type: ignore[override] self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any @@ -65,31 +107,21 @@ async def __call__( # type: ignore[override] retry_state.set_result(result) elif isinstance(do, DoSleep): retry_state.prepare_for_next_attempt() - await self.sleep(do) + await self.sleep(do) # type: ignore[misc] else: return do # type: ignore[no-any-return] - @classmethod - def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - if _utils.is_coroutine_callable(fn): - return fn - - async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: - return fn(*args, **kwargs) - - return inner - def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: - self.iter_state.actions.append(self._wrap_action_func(fn)) + self.iter_state.actions.append(_utils.wrap_to_async_func(fn)) async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override] - self.iter_state.retry_run_result = await self._wrap_action_func(self.retry)( + self.iter_state.retry_run_result = await _utils.wrap_to_async_func(self.retry)( retry_state ) async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override] if self.wait: - sleep = await self._wrap_action_func(self.wait)(retry_state) + sleep = await _utils.wrap_to_async_func(self.wait)(retry_state) else: sleep = 0.0 @@ -97,7 +129,7 @@ async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignor async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override] self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start - self.iter_state.stop_run_result = await self._wrap_action_func(self.stop)( + self.iter_state.stop_run_result = await _utils.wrap_to_async_func(self.stop)( retry_state ) @@ -127,7 +159,7 @@ async def __anext__(self) -> AttemptManager: return AttemptManager(retry_state=self._retry_state) elif isinstance(do, DoSleep): self._retry_state.prepare_for_next_attempt() - await self.sleep(do) + await self.sleep(do) # type: ignore[misc] else: raise StopAsyncIteration @@ -146,3 +178,13 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined] return async_wrapped # type: ignore[return-value] + + +__all__ = [ + "retry_all", + "retry_any", + "retry_if_exception", + "retry_if_result", + "WrappedFn", + "AsyncRetrying", +] diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py new file mode 100644 index 0000000..94b8b15 --- /dev/null +++ b/tenacity/asyncio/retry.py @@ -0,0 +1,125 @@ +# Copyright 2016–2021 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import typing + +from tenacity import _utils +from tenacity import retry_base + +if typing.TYPE_CHECKING: + from tenacity import RetryCallState + + +class async_retry_base(retry_base): + """Abstract base class for async retry strategies.""" + + @abc.abstractmethod + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + pass + + def __and__( # type: ignore[override] + self, other: "typing.Union[retry_base, async_retry_base]" + ) -> "retry_all": + return retry_all(self, other) + + def __rand__( # type: ignore[misc,override] + self, other: "typing.Union[retry_base, async_retry_base]" + ) -> "retry_all": + return retry_all(other, self) + + def __or__( # type: ignore[override] + self, other: "typing.Union[retry_base, async_retry_base]" + ) -> "retry_any": + return retry_any(self, other) + + def __ror__( # type: ignore[misc,override] + self, other: "typing.Union[retry_base, async_retry_base]" + ) -> "retry_any": + return retry_any(other, self) + + +RetryBaseT = typing.Union[ + async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]] +] + + +class retry_if_exception(async_retry_base): + """Retry strategy that retries if an exception verifies a predicate.""" + + def __init__( + self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]] + ) -> None: + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if retry_state.outcome.failed: + exception = retry_state.outcome.exception() + if exception is None: + raise RuntimeError("outcome failed but the exception is None") + return await self.predicate(exception) + else: + return False + + +class retry_if_result(async_retry_base): + """Retries if the result verifies a predicate.""" + + def __init__( + self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]] + ) -> None: + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if not retry_state.outcome.failed: + return await self.predicate(retry_state.outcome.result()) + else: + return False + + +class retry_any(async_retry_base): + """Retries if any of the retries condition is valid.""" + + def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None: + self.retries = retries + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + result = False + for r in self.retries: + result = result or await _utils.wrap_to_async_func(r)(retry_state) + if result: + break + return result + + +class retry_all(async_retry_base): + """Retries if all the retries condition are valid.""" + + def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None: + self.retries = retries + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + result = True + for r in self.retries: + result = result and await _utils.wrap_to_async_func(r)(retry_state) + if not result: + break + return result diff --git a/tenacity/retry.py b/tenacity/retry.py index c5e55a6..9211631 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -30,10 +30,16 @@ def __call__(self, retry_state: "RetryCallState") -> bool: pass def __and__(self, other: "retry_base") -> "retry_all": - return retry_all(self, other) + return other.__rand__(self) + + def __rand__(self, other: "retry_base") -> "retry_all": + return retry_all(other, self) def __or__(self, other: "retry_base") -> "retry_any": - return retry_any(self, other) + return other.__ror__(self) + + def __ror__(self, other: "retry_base") -> "retry_any": + return retry_any(other, self) RetryBaseT = typing.Union[retry_base, typing.Callable[["RetryCallState"], bool]] diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 24cf6ed..48f6286 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -22,8 +22,8 @@ import tenacity from tenacity import AsyncRetrying, RetryError -from tenacity import _asyncio as tasyncio -from tenacity import retry, retry_if_result, stop_after_attempt +from tenacity import asyncio as tasyncio +from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed from .test_tenacity import NoIOErrorAfterCount, current_time_ms @@ -202,6 +202,167 @@ def lt_3(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + async for attempt in tasyncio.AsyncRetrying( + retry=tasyncio.retry_if_result(lt_3) + ): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_exc(self): + async def test(): + attempts = 0 + + class CustomException(Exception): + pass + + async def is_exc(e: BaseException) -> bool: + return isinstance(e, CustomException) + + async for attempt in tasyncio.AsyncRetrying( + retry=tasyncio.retry_if_exception(is_exc) + ): + with attempt: + attempts += 1 + if attempts < 3: + raise CustomException() + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_or(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + class CustomException(Exception): + pass + + def is_exc(e: BaseException) -> bool: + return isinstance(e, CustomException) + + retry_strategy = tasyncio.retry_if_result(lt_3) | retry_if_exception(is_exc) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + if 2 < attempts < 4: + raise CustomException() + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(4, result) + + @asynctest + async def test_retry_with_async_result_ror(self): + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + class CustomException(Exception): + pass + + async def is_exc(e: BaseException) -> bool: + return isinstance(e, CustomException) + + retry_strategy = retry_if_result(lt_3) | tasyncio.retry_if_exception(is_exc) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + if 2 < attempts < 4: + raise CustomException() + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(4, result) + + @asynctest + async def test_retry_with_async_result_and(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + def gt_0(x: float) -> bool: + return x > 0 + + retry_strategy = tasyncio.retry_if_result(lt_3) & retry_if_result(gt_0) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_rand(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + def gt_0(x: float) -> bool: + return x > 0 + + retry_strategy = retry_if_result(gt_0) & tasyncio.retry_if_result(lt_3) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + @asynctest async def test_async_retying_iterator(self): thing = NoIOErrorAfterCount(5) From 5b00c1581a25d9777259f81562e9fc16f21c827e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 12:23:03 +0000 Subject: [PATCH 2/3] chore(deps): bump the github-actions group across 1 directory with 2 updates (#466) Bumps the github-actions group with 2 updates in the / directory: [actions/checkout](https://github.com/actions/checkout) and [actions/setup-python](https://github.com/actions/setup-python). Updates `actions/checkout` from 4.1.1 to 4.1.6 - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v4.1.1...v4.1.6) Updates `actions/setup-python` from 5.0.0 to 5.1.0 - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v5.0.0...v5.1.0) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-minor dependency-group: github-actions ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 4 ++-- .github/workflows/deploy.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fffad55..d4648a4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -34,12 +34,12 @@ jobs: tox: mypy steps: - name: Checkout 🛎️ - uses: actions/checkout@v4.1.1 + uses: actions/checkout@v4.1.6 with: fetch-depth: 0 - name: Setup Python 🔧 - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.1.0 with: python-version: ${{ matrix.python }} allow-prereleases: true diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml index 05fb50a..8035202 100644 --- a/.github/workflows/deploy.yaml +++ b/.github/workflows/deploy.yaml @@ -11,12 +11,12 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout 🛎️ - uses: actions/checkout@v4.1.1 + uses: actions/checkout@v4.1.6 with: fetch-depth: 0 - name: Setup Python 🔧 - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.1.0 with: python-version: 3.11 From 952189b4e33c02b5cd3fb0eb82dd318087f06d66 Mon Sep 17 00:00:00 2001 From: Martin Beckert Date: Thu, 13 Jun 2024 10:46:08 +0200 Subject: [PATCH 3/3] Update index.rst: Remove * (#465) Co-authored-by: Julien Danjou --- doc/source/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/index.rst b/doc/source/index.rst index bdf7ff2..3f0764a 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -79,7 +79,7 @@ Examples Basic Retry ~~~~~~~~~~~ -.. testsetup:: * +.. testsetup:: import logging #