Skip to content

Commit

Permalink
Very simple Retrier class
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Jan 2, 2024
1 parent dbae967 commit 75f10fd
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lvmopstools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@


__version__ = get_package_version(path=__file__, package_name="lvmopstools")


from .retrier import *
61 changes: 61 additions & 0 deletions src/lvmopstools/retrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego ([email protected])
# @Date: 2024-01-02
# @Filename: retrier.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

from __future__ import annotations

import asyncio
import time
from dataclasses import dataclass

from typing import Callable


__all__ = ["Retrier", "retrier"]


@dataclass
class Retrier:
"""A class that implements a retry mechanism."""

max_attempts: int = 3
delay: float = 0.01

def __call__(self, func: Callable):
"""Wraps a function to retry it if it fails."""

is_coroutine = asyncio.iscoroutinefunction(func)

def wrapper(*args, **kwargs):
attempt = 0
while True:
try:
return func(*args, **kwargs)
except Exception as ee:
attempt += 1
if attempt >= self.max_attempts:
raise ee
else:
time.sleep(self.delay)

async def async_wrapper(*args, **kwargs):
attempt = 0
while True:
try:
return await func(*args, **kwargs)
except Exception as ee:
attempt += 1
if attempt >= self.max_attempts:
raise ee
else:
await asyncio.sleep(self.delay)

return wrapper if not is_coroutine else async_wrapper


# Mostly to use as a decorator and maintain the standard that functions are lowercase.
retrier = Retrier
65 changes: 65 additions & 0 deletions tests/test_retrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego ([email protected])
# @Date: 2024-01-02
# @Filename: test_retrier.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

from __future__ import annotations

import pytest

from lvmopstools import retrier


@pytest.mark.parametrize("fail", [False, True])
async def test_retrier(fail: bool):
global n_attempts

n_attempts = 0

@retrier(max_attempts=3)
def test_function():
global n_attempts

if fail:
raise ValueError()

n_attempts += 1
if n_attempts == 2:
return True
else:
raise ValueError()

if fail:
with pytest.raises(ValueError):
test_function()
else:
assert test_function() is True


@pytest.mark.parametrize("fail", [False, True])
async def test_retrier_async(fail: bool):
global n_attempts

n_attempts = 0

@retrier(max_attempts=3)
async def test_function():
global n_attempts

if fail:
raise ValueError()

n_attempts += 1
if n_attempts == 2:
return True
else:
raise ValueError()

if fail:
with pytest.raises(ValueError):
await test_function()
else:
assert (await test_function()) is True

0 comments on commit 75f10fd

Please sign in to comment.