From e5a5387870f7b689fc631ad46b8cc27de10411a5 Mon Sep 17 00:00:00 2001 From: Hasier Date: Thu, 4 Jul 2024 17:32:19 +0100 Subject: [PATCH] Add tests --- tenacity/asyncio/retry.py | 6 +- tenacity/retry.py | 4 +- tests/test_asyncio.py | 352 +++++++++++++++++++++++++++++++++++++- tests/test_tenacity.py | 110 ++++++++++++ 4 files changed, 467 insertions(+), 5 deletions(-) diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py index 94b8b15..f458bad 100644 --- a/tenacity/asyncio/retry.py +++ b/tenacity/asyncio/retry.py @@ -104,7 +104,7 @@ def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] result = False for r in self.retries: - result = result or await _utils.wrap_to_async_func(r)(retry_state) + result = result or (await _utils.wrap_to_async_func(r)(retry_state) is True) if result: break return result @@ -119,7 +119,9 @@ def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] result = True for r in self.retries: - result = result and await _utils.wrap_to_async_func(r)(retry_state) + result = result and ( + await _utils.wrap_to_async_func(r)(retry_state) is True + ) if not result: break return result diff --git a/tenacity/retry.py b/tenacity/retry.py index 6584c5b..c05c16b 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -283,7 +283,7 @@ def __init__(self, *retries: retry_base) -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: - return any(r(retry_state) for r in self.retries) + return any(r(retry_state) is True for r in self.retries) class retry_all(retry_base): @@ -293,4 +293,4 @@ def __init__(self, *retries: retry_base) -> None: self.retries = retries def __call__(self, retry_state: "RetryCallState") -> bool: - return all(r(retry_state) for r in self.retries) + return all(r(retry_state) is True for r in self.retries) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 8716529..69ea883 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -28,7 +28,7 @@ import pytest import tenacity -from tenacity import AsyncRetrying, RetryError +from tenacity import AsyncRetrying, RetryCallState, RetryError from tenacity import asyncio as tasyncio from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed @@ -308,6 +308,98 @@ def is_exc(e: BaseException) -> bool: self.assertEqual(4, result) + @asynctest + async def test_retry_with_async_result_or_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tasyncio.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_or_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tasyncio.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_or_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tenacity.retry_if_result(lt_3) | should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + # It does not correctly work as the function is not called! + self.assertFalse(called) + self.assertEqual(3, result) + @asynctest async def test_retry_with_async_result_ror(self): async def test(): @@ -339,6 +431,98 @@ async def is_exc(e: BaseException) -> bool: self.assertEqual(4, result) + @asynctest + async def test_retry_with_async_result_ror_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_ror_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_ror_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tenacity.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + # It does not correctly work as the function is not called! + self.assertFalse(called) + self.assertEqual(3, result) + @asynctest async def test_retry_with_async_result_and(self): async def test(): @@ -362,6 +546,89 @@ def gt_0(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_and_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tasyncio.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_and_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tasyncio.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_and_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tenacity.retry_if_result(lt_3) & should_retry # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + # It does not correctly work as the function is not called! + self.assertFalse(called) + self.assertEqual(1, result) + @asynctest async def test_retry_with_async_result_rand(self): async def test(): @@ -385,6 +652,89 @@ def gt_0(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_rand_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_rand_async_func(self): + async def test(): + attempts = 0 + called = False + + async def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tasyncio.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_sync_retry_with_async_result_rand_async_func(self): + called = False + + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + async def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tenacity.retry_if_result(lt_3) # type: ignore[operator] + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + # It does not correctly work as the function is not called! + self.assertFalse(called) + self.assertEqual(1, result) + @asynctest async def test_async_retying_iterator(self): thing = NoIOErrorAfterCount(5) diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index e158fa6..07d4155 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -633,6 +633,58 @@ def r(fut): self.assertFalse(r(tenacity.Future.construct(1, 3, False))) self.assertFalse(r(tenacity.Future.construct(1, 1, True))) + async def test_retry_and_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = tenacity.retry_if_result(lt_3) & should_retry # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + + async def test_retry_rand_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return True + + retry_strategy = should_retry & tenacity.retry_if_result(lt_3) # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + def test_retry_or(self): retry = tenacity.retry_if_result( lambda x: x == "foo" @@ -647,6 +699,64 @@ def r(fut): self.assertFalse(r(tenacity.Future.construct(1, 2.2, False))) self.assertFalse(r(tenacity.Future.construct(1, 42, True))) + def test_retry_or_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = tenacity.retry_if_result(lt_3) | should_retry # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + + def test_retry_ror_func(self): + def test(): + attempts = 0 + called = False + + def lt_3(x: float) -> bool: + return x < 3 + + def should_retry(retry_state: RetryCallState) -> bool: + nonlocal called + called = True + return False + + retry_strategy = should_retry | tenacity.retry_if_result(lt_3) # type: ignore[operator] + for attempt in Retrying(retry=retry_strategy): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + self.assertTrue(called) + return attempts + + result = test() + + self.assertEqual(3, result) + def _raise_try_again(self): self._attempts += 1 if self._attempts < 3: