Skip to content

Commit

Permalink
Avoid overwriting local contexts with retry decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
hasier committed Jun 24, 2024
1 parent ee6a8f7 commit 478dc03
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 6 deletions.
10 changes: 8 additions & 2 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,19 @@ def wraps(self, f: WrappedFn) -> WrappedFn:
f, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
)
def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any:
return self(f, *args, **kw)
# Always create a copy to prevent overwriting the local contexts when
# calling the same wrapped functions multiple times in the same stack
copy = self.copy()
wrapped_f.statistics = copy.statistics # type: ignore[attr-defined]
return copy(f, *args, **kw)

def retry_with(*args: t.Any, **kwargs: t.Any) -> WrappedFn:
return self.copy(*args, **kwargs).wraps(f)

wrapped_f.retry = self # type: ignore[attr-defined]
# Preserve attributes
wrapped_f.retry = wrapped_f # type: ignore[attr-defined]
wrapped_f.retry_with = retry_with # type: ignore[attr-defined]
wrapped_f.statistics = {} # type: ignore[attr-defined]

return wrapped_f # type: ignore[return-value]

Expand Down
13 changes: 9 additions & 4 deletions tenacity/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,23 @@ async def __anext__(self) -> AttemptManager:
raise StopAsyncIteration

def wraps(self, fn: WrappedFn) -> WrappedFn:
fn = super().wraps(fn)
wrapped = super().wraps(fn)
# Ensure wrapper is recognized as a coroutine function.

@functools.wraps(
fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
)
async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
return await fn(*args, **kwargs)
# Always create a copy to prevent overwriting the local contexts when
# calling the same wrapped functions multiple times in the same stack
copy = self.copy()
async_wrapped.statistics = copy.statistics # type: ignore[attr-defined]
return await copy(fn, *args, **kwargs)

# Preserve attributes
async_wrapped.retry = fn.retry # type: ignore[attr-defined]
async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined]
async_wrapped.retry = async_wrapped # type: ignore[attr-defined]
async_wrapped.retry_with = wrapped.retry_with # type: ignore[attr-defined]
async_wrapped.statistics = {} # type: ignore[attr-defined]

return async_wrapped # type: ignore[return-value]

Expand Down
118 changes: 118 additions & 0 deletions tests/test_issue_478.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import asyncio
import typing
import unittest

from functools import wraps

from tenacity import RetryCallState, retry


def asynctest(
callable_: typing.Callable[..., typing.Any],
) -> typing.Callable[..., typing.Any]:
@wraps(callable_)
def wrapper(*a: typing.Any, **kw: typing.Any) -> typing.Any:
loop = asyncio.get_event_loop()
return loop.run_until_complete(callable_(*a, **kw))

return wrapper


MAX_RETRY_FIX_ATTEMPTS = 2


class TestIssue478(unittest.TestCase):
def test_issue(self) -> None:
results = []

def do_retry(retry_state: RetryCallState) -> bool:
outcome = retry_state.outcome
assert outcome
ex = outcome.exception()
_subject_: str = retry_state.args[0]

if _subject_ == "Fix": # no retry on fix failure
return False

if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS:
return False

if ex:
do_fix_work()
return True

return False

@retry(reraise=True, retry=do_retry)
def _do_work(subject: str) -> None:
if subject == "Error":
results.append(f"{subject} is not working")
raise Exception(f"{subject} is not working")
results.append(f"{subject} is working")

def do_any_work(subject: str) -> None:
_do_work(subject)

def do_fix_work() -> None:
_do_work("Fix")

try:
do_any_work("Error")
except Exception as exc:
assert str(exc) == "Error is not working"
else:
assert False, "No exception caught"

assert results == [
"Error is not working",
"Fix is working",
"Error is not working",
]

@asynctest
async def test_async(self) -> None:
results = []

async def do_retry(retry_state: RetryCallState) -> bool:
outcome = retry_state.outcome
assert outcome
ex = outcome.exception()
_subject_: str = retry_state.args[0]

if _subject_ == "Fix": # no retry on fix failure
return False

if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS:
return False

if ex:
await do_fix_work()
return True

return False

@retry(reraise=True, retry=do_retry)
async def _do_work(subject: str) -> None:
if subject == "Error":
results.append(f"{subject} is not working")
raise Exception(f"{subject} is not working")
results.append(f"{subject} is working")

async def do_any_work(subject: str) -> None:
await _do_work(subject)

async def do_fix_work() -> None:
await _do_work("Fix")

try:
await do_any_work("Error")
except Exception as exc:
assert str(exc) == "Error is not working"
else:
assert False, "No exception caught"

assert results == [
"Error is not working",
"Fix is working",
"Error is not working",
]

0 comments on commit 478dc03

Please sign in to comment.