From cf6e0dcf40da9bb7e6b8e9d3e0ce3406d867efc1 Mon Sep 17 00:00:00 2001
From: ilya
Date: Wed, 22 Nov 2023 20:54:36 +0000
Subject: [PATCH] Tests

---
 crates/shared/src/rate_limiter.rs           |   6 +-
 crates/solvers/src/boundary/rate_limiter.rs | 128 ++++++++++++++++++--
 2 files changed, 120 insertions(+), 14 deletions(-)

diff --git a/crates/shared/src/rate_limiter.rs b/crates/shared/src/rate_limiter.rs
index 34d5035f6e..e3e67200f9 100644
--- a/crates/shared/src/rate_limiter.rs
+++ b/crates/shared/src/rate_limiter.rs
@@ -33,9 +33,9 @@ pub struct RateLimitingStrategy {
     drop_requests_until: Instant,
     /// How many requests got rate limited in a row.
     times_rate_limited: u64,
-    back_off_growth_factor: f64,
-    min_back_off: Duration,
-    max_back_off: Duration,
+    pub back_off_growth_factor: f64,
+    pub min_back_off: Duration,
+    pub max_back_off: Duration,
 }
 
 impl Default for RateLimitingStrategy {
diff --git a/crates/solvers/src/boundary/rate_limiter.rs b/crates/solvers/src/boundary/rate_limiter.rs
index a962595369..3b60d15cb1 100644
--- a/crates/solvers/src/boundary/rate_limiter.rs
+++ b/crates/solvers/src/boundary/rate_limiter.rs
@@ -2,9 +2,10 @@ use {
     anyhow::{ensure, Context, Result},
     shared::rate_limiter::{
         RateLimiter as SharedRateLimiter,
+        RateLimiterError as SharedRateLimiterError,
         RateLimitingStrategy as SharedRateLimitingStrategy,
     },
-    std::{future::Future, str::FromStr, time::Duration},
+    std::{future::Future, ops::Add, str::FromStr, time::Duration},
     thiserror::Error,
 };
 
@@ -52,7 +53,7 @@ impl FromStr for RateLimitingStrategy {
     }
 }
 
-#[derive(Error, Debug, Clone, Default)]
+#[derive(Error, Debug, Clone, Default, PartialEq)]
 pub enum RateLimiterError {
     #[default]
     #[error("rate limited")]
@@ -75,7 +76,9 @@ impl RateLimiter {
         self.inner
             .execute(task, requires_back_off)
             .await
-            .map_err(|_| RateLimiterError::RateLimited)
+            .map_err(|err| match err {
+                SharedRateLimiterError::RateLimited => RateLimiterError::RateLimited,
+            })
     }
 
     pub async fn execute_with_retries(
@@ -89,19 +92,122 @@
     {
         let mut retries = 0;
         while retries < self.max_retries {
-            match self.execute(task(), requires_back_off.clone()).await {
-                Ok(result) => return Ok(result),
-                Err(RateLimiterError::RateLimited) => {
-                    let back_off_duration = self.get_back_off_duration();
-                    tokio::time::sleep(back_off_duration).await;
-                    retries += 1;
-                }
+            let result = self.execute(task(), requires_back_off.clone()).await;
+            let should_retry = match &result {
+                Ok(result) => requires_back_off.clone()(result),
+                Err(RateLimiterError::RateLimited) => true,
+            };
+
+            if should_retry {
+                let back_off_duration = self.get_back_off_duration();
+                tokio::time::sleep(back_off_duration).await;
+                retries += 1;
+            } else {
+                return result;
             }
         }
         Err(RateLimiterError::RateLimited)
     }
 
     fn get_back_off_duration(&self) -> Duration {
-        self.inner.strategy.lock().unwrap().get_current_back_off()
+        self.inner
+            .strategy
+            .lock()
+            .unwrap()
+            .get_current_back_off()
+            // add 100 millis to make sure the RateLimiter has updated its counter
+            .add(Duration::from_millis(100))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use {
+        super::*,
+        std::sync::atomic::{AtomicUsize, Ordering},
+    };
+
+    #[tokio::test]
+    async fn test_execute_with_retries() {
+        let strategy = RateLimitingStrategy::default();
+        let rate_limiter = RateLimiter::new(strategy, "test".to_string());
+        let call_count = AtomicUsize::new(0);
+
+        let task = || {
+            let count = call_count.fetch_add(1, Ordering::SeqCst);
+            async move {
+                if count < 1 {
+                    Err(RateLimiterError::RateLimited)
+                } else {
+                    Ok(42)
+                }
+            }
+        };
+
+        let result = rate_limiter
+            .execute_with_retries(task, |res| {
+                let back_off_required = matches!(res, Err(RateLimiterError::RateLimited));
+                back_off_required
+            })
+            .await
+            .and_then(|result: Result<i32, RateLimiterError>| result);
+        assert_eq!(result, Ok(42));
+        assert_eq!(call_count.load(Ordering::SeqCst), 2);
+    }
+
+    #[tokio::test]
+    async fn test_execute_with_retries_exceeds() {
+        let strategy = RateLimitingStrategy::default();
+        let rate_limiter = RateLimiter::new(strategy, "test".to_string());
+        let call_count = AtomicUsize::new(0);
+
+        let task = || {
+            call_count.fetch_add(1, Ordering::SeqCst);
+            async move { Err(RateLimiterError::RateLimited) }
+        };
+
+        let result = rate_limiter
+            .execute_with_retries(task, |res| {
+                let back_off_required = matches!(res, Err(RateLimiterError::RateLimited));
+                back_off_required
+            })
+            .await
+            .and_then(|result: Result<i32, RateLimiterError>| result);
+        assert_eq!(result, Err(RateLimiterError::RateLimited));
+        assert_eq!(call_count.load(Ordering::SeqCst), 2);
+    }
+}
+
+#[cfg(test)]
+mod config_tests {
+    use super::*;
+
+    #[test]
+    fn parse_rate_limiting_strategy() {
+        let config_str = "1.5,10,30,3";
+        let strategy: RateLimitingStrategy = config_str.parse().unwrap();
+        assert_eq!(strategy.inner.back_off_growth_factor, 1.5);
+        assert_eq!(strategy.inner.min_back_off, Duration::from_secs(10));
+        assert_eq!(strategy.inner.max_back_off, Duration::from_secs(30));
+        assert_eq!(strategy.max_retries, 3);
+    }
+
+    #[test]
+    fn parse_rate_limiting_strategy_with_default_retries() {
+        let config_str = "1.5,10,30";
+        let strategy: RateLimitingStrategy = config_str.parse().unwrap();
+        assert_eq!(strategy.max_retries, DEFAULT_MAX_RETIRES);
+    }
+
+    #[test]
+    fn parse_invalid_rate_limiting_strategy() {
+        let config_str = "invalid";
+        assert!(config_str.parse::<RateLimitingStrategy>().is_err());
+    }
+
+    #[test]
+    fn parse_too_many_args_rate_limiting_strategy() {
+        let config_str = "1.5,10,30,3,10";
+        assert!(config_str.parse::<RateLimitingStrategy>().is_err());
+    }
+}
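
Usage note (outside the patch): a minimal sketch of how the retry wrapper above is meant to be driven. `RateLimiter`, `RateLimiterError`, and `execute_with_retries` are the items from the diff; `fetch_quote`, `quote_with_retries`, and the `crate::boundary::rate_limiter` import path are hypothetical stand-ins.

    // Sketch only; the names marked hypothetical above are not in the patch.
    use crate::boundary::rate_limiter::{RateLimiter, RateLimiterError};

    // Hypothetical task; a real one would issue the request and map
    // HTTP 429 responses to RateLimiterError::RateLimited.
    async fn fetch_quote() -> Result<u64, RateLimiterError> {
        Ok(42)
    }

    async fn quote_with_retries(limiter: &RateLimiter) -> Result<u64, RateLimiterError> {
        limiter
            // The predicate decides when to back off and retry; it sees the
            // whole result, so it can also flag an Ok payload for a retry.
            .execute_with_retries(fetch_quote, |res| {
                matches!(res, Err(RateLimiterError::RateLimited))
            })
            .await
            // execute_with_retries yields Result<Result<_, _>, _>;
            // flatten it the same way the tests do.
            .and_then(|inner| inner)
    }

The extra 100 ms that get_back_off_duration adds to the sleep makes the retry fire only after the shared limiter has left its back-off window and updated its counter, per the comment in the diff.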