Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async version of retry_on_exceptions_with_backoff utility #16374

73 changes: 73 additions & 0 deletions llama-index-core/llama_index/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,79 @@ def retry_on_exceptions_with_backoff(
backoff_secs = min(backoff_secs * 2, max_backoff_secs)


async def aretry_on_exceptions_with_backoff(
async_fn: Callable,
errors_to_retry: List[ErrorToRetry],
max_tries: int = 10,
min_backoff_secs: float = 0.5,
max_backoff_secs: float = 60.0,
) -> Any:
"""Execute lambda function with retries and exponential backoff.

Args:
async_fn (Callable): Async Function to be called and output we want.
errors_to_retry (List[ErrorToRetry]): List of errors to retry.
At least one needs to be provided.
max_tries (int): Maximum number of tries, including the first. Defaults to 10.
min_backoff_secs (float): Minimum amount of backoff time between attempts.
Defaults to 0.5.
max_backoff_secs (float): Maximum amount of backoff time between attempts.
Defaults to 60.

"""
if not errors_to_retry:
raise ValueError("At least one error to retry needs to be provided")

error_checks = {
error_to_retry.exception_cls: error_to_retry.check_fn
for error_to_retry in errors_to_retry
}
exception_class_tuples = tuple(error_checks.keys())

backoff_secs = min_backoff_secs
tries = 0

while True:
try:
return await async_fn()
except exception_class_tuples as e:
traceback.print_exc()
tries += 1
if tries >= max_tries:
raise
check_fn = error_checks.get(e.__class__)
if check_fn and not check_fn(e):
raise
time.sleep(backoff_secs)
backoff_secs = min(backoff_secs * 2, max_backoff_secs)


def get_retry_on_exceptions_with_backoff_decorator(
*retry_args: Any, **retry_kwargs: Any
) -> Callable:
"""Return a decorator that retries with exponential backoff on provided exceptions."""

def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*func_args: Any, **func_kwargs: Any) -> Any:
return retry_on_exceptions_with_backoff(
lambda: func(*func_args, **func_kwargs), *retry_args, **retry_kwargs
)

@wraps(func)
async def awrapper(*func_args: Any, **func_kwargs: Any) -> Any:
async def foo() -> Any:
return await func(*func_args, **func_kwargs)

return await aretry_on_exceptions_with_backoff(
foo, *retry_args, **retry_kwargs
)

return awrapper if asyncio.iscoroutinefunction(func) else wrapper

return decorator


def truncate_text(text: str, max_length: int) -> str:
"""Truncate text to a maximum length."""
if len(text) <= max_length:
Expand Down
70 changes: 70 additions & 0 deletions llama-index-core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
iter_batch,
print_text,
retry_on_exceptions_with_backoff,
get_retry_on_exceptions_with_backoff_decorator,
)


Expand Down Expand Up @@ -81,6 +82,75 @@ def test_retry_on_exceptions_with_backoff() -> None:
assert call_count == 1


@pytest.mark.asyncio()
async def test_retry_on_exceptions_with_backoff_decorator() -> None:
"""Make sure retry decorator works for both sync and async functions."""
global call_count
call_count = 0

retry_on_value_error = get_retry_on_exceptions_with_backoff_decorator(
[ErrorToRetry(ValueError)],
max_tries=3,
min_backoff_secs=0.0,
)

SUCCESS_MESSAGE = "done"

@retry_on_value_error
def fn_with_exception(exception, n=2) -> None:
global call_count
call_count += 1
if call_count >= n:
return SUCCESS_MESSAGE
raise exception

@retry_on_value_error
async def async_fn_with_exception(exception, n=2) -> None:
global call_count
call_count += 1
if call_count >= n:
return SUCCESS_MESSAGE
raise exception

# sync function
# should retry 3 times
call_count = 0
with pytest.raises(ValueError):
result = fn_with_exception(ValueError, 5)
assert call_count == 3

# should not raise exception
call_count = 0
result = fn_with_exception(ValueError, 2)
assert result == SUCCESS_MESSAGE
assert call_count == 2

# different exception will not get retried
call_count = 0
with pytest.raises(TypeError):
result = fn_with_exception(TypeError, 2)
assert call_count == 1

# Async function
# should retry 3 times
call_count = 0
with pytest.raises(ValueError):
result = await async_fn_with_exception(ValueError, 5)
assert call_count == 3

# should not raise exception
call_count = 0
result = await async_fn_with_exception(ValueError, 2)
assert result == SUCCESS_MESSAGE
assert call_count == 2

# different exception will not get retried
call_count = 0
with pytest.raises(TypeError):
result = await async_fn_with_exception(TypeError, 2)
assert call_count == 1


def test_retry_on_conditional_exceptions() -> None:
"""Make sure retry function works on conditional exceptions."""
global call_count
Expand Down
Loading