Skip to content

Commit

Permalink
Check for async strategies in sync context
Browse files Browse the repository at this point in the history
  • Loading branch information
hasier committed Jul 8, 2024
1 parent d5f8bff commit f8725d6
Showing 3 changed files with 81 additions and 26 deletions.
6 changes: 2 additions & 4 deletions tenacity/asyncio/retry.py
Original file line number Diff line number Diff line change
@@ -104,7 +104,7 @@ def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None
async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
result = False
for r in self.retries:
result = result or (await _utils.wrap_to_async_func(r)(retry_state) is True)
result = result or await _utils.wrap_to_async_func(r)(retry_state)
if result:
break
return result
@@ -119,9 +119,7 @@ def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None
async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override]
result = True
for r in self.retries:
result = result and (
await _utils.wrap_to_async_func(r)(retry_state) is True
)
result = result and await _utils.wrap_to_async_func(r)(retry_state)
if not result:
break
return result
41 changes: 39 additions & 2 deletions tenacity/retry.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,13 @@
import re
import typing

from . import _utils

try:
import tornado
except ImportError:
tornado = None

if typing.TYPE_CHECKING:
from tenacity import RetryCallState

@@ -283,7 +290,22 @@ def __init__(self, *retries: retry_base) -> None:
self.retries = retries

def __call__(self, retry_state: "RetryCallState") -> bool:
return any(r(retry_state) is True for r in self.retries)
result = False
for r in self.retries:
if _utils.is_coroutine_callable(r) or (
tornado
and hasattr(tornado.gen, "is_coroutine_function")
and tornado.gen.is_coroutine_function(r)
):
raise TypeError(
"Cannot use async functions in a sync context. Make sure "
"you use the correct retrying object and the corresponding "
"async strategies"
)
result = result or r(retry_state)
if result:
break
return result


class retry_all(retry_base):
@@ -293,4 +315,19 @@ def __init__(self, *retries: retry_base) -> None:
self.retries = retries

def __call__(self, retry_state: "RetryCallState") -> bool:
return all(r(retry_state) is True for r in self.retries)
result = True
for r in self.retries:
if _utils.is_coroutine_callable(r) or (
tornado
and hasattr(tornado.gen, "is_coroutine_function")
and tornado.gen.is_coroutine_function(r)
):
raise TypeError(
"Cannot use async functions in a sync context. Make sure "
"you use the correct retrying object and the corresponding "
"async strategies"
)
result = result and r(retry_state)
if not result:
break
return result
60 changes: 40 additions & 20 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
@@ -394,11 +394,16 @@ async def should_retry(retry_state: RetryCallState) -> bool:

return attempts

result = await test()

# It does not correctly work as the function is not called!
self.assertFalse(called)
self.assertEqual(3, result)
try:
await test()
except TypeError as exc:
self.assertEqual(
str(exc),
"Cannot use async functions in a sync context. Make sure you use "
"the correct retrying object and the corresponding async strategies",
)
else:
self.fail("This is an invalid retry combination that should have failed")

@asynctest
async def test_retry_with_async_result_ror(self):
@@ -517,11 +522,16 @@ async def should_retry(retry_state: RetryCallState) -> bool:

return attempts

result = await test()

# It does not correctly work as the function is not called!
self.assertFalse(called)
self.assertEqual(3, result)
try:
await test()
except TypeError as exc:
self.assertEqual(
str(exc),
"Cannot use async functions in a sync context. Make sure you use "
"the correct retrying object and the corresponding async strategies",
)
else:
self.fail("This is an invalid retry combination that should have failed")

@asynctest
async def test_retry_with_async_result_and(self):
@@ -623,11 +633,16 @@ async def should_retry(retry_state: RetryCallState) -> bool:

return attempts

result = await test()

# It does not correctly work as the function is not called!
self.assertFalse(called)
self.assertEqual(1, result)
try:
await test()
except TypeError as exc:
self.assertEqual(
str(exc),
"Cannot use async functions in a sync context. Make sure you use "
"the correct retrying object and the corresponding async strategies",
)
else:
self.fail("This is an invalid retry combination that should have failed")

@asynctest
async def test_retry_with_async_result_rand(self):
@@ -729,11 +744,16 @@ async def should_retry(retry_state: RetryCallState) -> bool:

return attempts

result = await test()

# It does not correctly work as the function is not called!
self.assertFalse(called)
self.assertEqual(1, result)
try:
await test()
except TypeError as exc:
self.assertEqual(
str(exc),
"Cannot use async functions in a sync context. Make sure you use "
"the correct retrying object and the corresponding async strategies",
)
else:
self.fail("This is an invalid retry combination that should have failed")

@asynctest
async def test_async_retying_iterator(self):

0 comments on commit f8725d6

Please sign in to comment.