diff --git a/releasenotes/notes/improve-wait-exponential-jitter-e0a3a27fef39f97f.yaml b/releasenotes/notes/improve-wait-exponential-jitter-e0a3a27fef39f97f.yaml new file mode 100644 index 0000000..75736d1 --- /dev/null +++ b/releasenotes/notes/improve-wait-exponential-jitter-e0a3a27fef39f97f.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + Improves wait_exponential_jitter to inherit wait_exponential + - Reuses `wait_exponential.__call__` method instead of duplicating code + - Follows argument names of wait_exponential + - Adds `min` argument + - Supports supplying `max`, `jitter`, `min` arguments as `timedelta` diff --git a/tenacity/wait.py b/tenacity/wait.py index dc3c850..1476a44 100644 --- a/tenacity/wait.py +++ b/tenacity/wait.py @@ -17,6 +17,7 @@ import abc import random import typing +import warnings from tenacity import _utils @@ -200,35 +201,41 @@ def __call__(self, retry_state: "RetryCallState") -> float: return random.uniform(self.min, high) -class wait_exponential_jitter(wait_base): +class wait_exponential_jitter(wait_exponential): """Wait strategy that applies exponential backoff and jitter. - It allows for a customized initial wait, maximum wait and jitter. + It allows for a customized multiplier, max wait, jitter and min wait. This implements the strategy described here: https://cloud.google.com/storage/docs/retry-strategy - The wait time is min(initial * 2**n + random.uniform(0, jitter), maximum) + The wait time is max(min, min(multiplier * 2**n + random.uniform(0, jitter), max)) where n is the retry count. """ def __init__( self, - initial: float = 1, - max: float = _utils.MAX_WAIT, # noqa - exp_base: float = 2, - jitter: float = 1, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + jitter: _utils.time_unit_type = 1, + min: _utils.time_unit_type = 0, # noqa + initial: typing.Union[int, float, None] = None, ) -> None: - self.initial = initial - self.max = max - self.exp_base = exp_base - self.jitter = jitter + if initial is not None and multiplier != 1: + msg = ( + "Received both `multiplier` and `initial` arguments. " + "`initial` is deprecated, use `multiplier` instead." + ) + raise ValueError(msg) + elif initial is not None: + msg = "`initial` is deprecated, use `multiplier` instead." + warnings.warn(msg, DeprecationWarning) + multiplier = initial + + super().__init__(multiplier, max, exp_base, min) + self.jitter = _utils.to_seconds(jitter) def __call__(self, retry_state: "RetryCallState") -> float: - jitter = random.uniform(0, self.jitter) - try: - exp = self.exp_base ** (retry_state.attempt_number - 1) - result = self.initial * exp + jitter - except OverflowError: - result = self.max - return max(0, min(result, self.max)) + result = super().__call__(retry_state) + random.uniform(0, self.jitter) + return max(self.min, min(result, self.max)) diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index b76fec2..666fd4f 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -523,9 +523,15 @@ def test_wait_exponential_jitter(self): self.assertEqual(fn(make_retry_state(8, 0)), 60) self.assertEqual(fn(make_retry_state(9, 0)), 60) - fn = tenacity.wait_exponential_jitter(10, 5) + max_wait = 5 + fn = tenacity.wait_exponential_jitter(10, max_wait) + for _ in range(1000): + self._assert_inclusive_range(fn(make_retry_state(1, 0)), 0, max_wait) + + min_wait = 5 + fn = tenacity.wait_exponential_jitter(min=min_wait) for _ in range(1000): - self.assertEqual(fn(make_retry_state(1, 0)), 5) + self._assert_inclusive_range(fn(make_retry_state(1, 0)), min_wait, 6) # Default arguments exist fn = tenacity.wait_exponential_jitter()