Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
squadgazzz committed Nov 22, 2023
1 parent 4c9c8e9 commit cf6e0dc
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 14 deletions.
6 changes: 3 additions & 3 deletions crates/shared/src/rate_limiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ pub struct RateLimitingStrategy {
drop_requests_until: Instant,
/// How many requests got rate limited in a row.
times_rate_limited: u64,
back_off_growth_factor: f64,
min_back_off: Duration,
max_back_off: Duration,
pub back_off_growth_factor: f64,
pub min_back_off: Duration,
pub max_back_off: Duration,
}

impl Default for RateLimitingStrategy {
Expand Down
128 changes: 117 additions & 11 deletions crates/solvers/src/boundary/rate_limiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use {
anyhow::{ensure, Context, Result},
shared::rate_limiter::{
RateLimiter as SharedRateLimiter,
RateLimiterError as SharedRateLimiterError,
RateLimitingStrategy as SharedRateLimitingStrategy,
},
std::{future::Future, str::FromStr, time::Duration},
std::{future::Future, ops::Add, str::FromStr, time::Duration},
thiserror::Error,
};

Expand Down Expand Up @@ -52,7 +53,7 @@ impl FromStr for RateLimitingStrategy {
}
}

#[derive(Error, Debug, Clone, Default)]
#[derive(Error, Debug, Clone, Default, PartialEq)]
pub enum RateLimiterError {
#[default]
#[error("rate limited")]
Expand All @@ -75,7 +76,9 @@ impl RateLimiter {
self.inner
.execute(task, requires_back_off)
.await
.map_err(|_| RateLimiterError::RateLimited)
.map_err(|err| match err {
SharedRateLimiterError::RateLimited => RateLimiterError::RateLimited,
})
}

pub async fn execute_with_retries<T, F, Fut>(
Expand All @@ -89,19 +92,122 @@ impl RateLimiter {
{
let mut retries = 0;
while retries < self.max_retries {
match self.execute(task(), requires_back_off.clone()).await {
Ok(result) => return Ok(result),
Err(RateLimiterError::RateLimited) => {
let back_off_duration = self.get_back_off_duration();
tokio::time::sleep(back_off_duration).await;
retries += 1;
}
let result = self.execute(task(), requires_back_off.clone()).await;
let should_retry = match &result {
Ok(result) => requires_back_off.clone()(result),
Err(RateLimiterError::RateLimited) => true,
};

if should_retry {
let back_off_duration = self.get_back_off_duration();
tokio::time::sleep(back_off_duration).await;
retries += 1;
} else {
return result;
}
}
Err(RateLimiterError::RateLimited)
}

fn get_back_off_duration(&self) -> Duration {
self.inner.strategy.lock().unwrap().get_current_back_off()
self.inner
.strategy
.lock()
.unwrap()
.get_current_back_off()
// add 100 millis to make sure the RateLimiter updated it's counter
.add(Duration::from_millis(100))
}
}

#[cfg(test)]
mod tests {
use {
super::*,
std::sync::atomic::{AtomicUsize, Ordering},
};

#[tokio::test]
async fn test_execute_with_retries() {
let strategy = RateLimitingStrategy::default();
let rate_limiter = RateLimiter::new(strategy, "test".to_string());
let call_count = AtomicUsize::new(0);

let task = || {
let count = call_count.fetch_add(1, Ordering::SeqCst);
async move {
if count < 1 {
Err(RateLimiterError::RateLimited)
} else {
Ok(42)
}
}
};

let result = rate_limiter
.execute_with_retries(task, |res| {
let back_off_required = matches!(res, Err(RateLimiterError::RateLimited));
back_off_required
})
.await
.and_then(|result: Result<i32, RateLimiterError>| result);
assert_eq!(result, Ok(42));
assert_eq!(call_count.load(Ordering::SeqCst), 2);
}

#[tokio::test]
async fn test_execute_with_retries_exceeds() {
let strategy = RateLimitingStrategy::default();
let rate_limiter = RateLimiter::new(strategy, "test".to_string());
let call_count = AtomicUsize::new(0);

let task = || {
call_count.fetch_add(1, Ordering::SeqCst);
async move { Err(RateLimiterError::RateLimited) }
};

let result = rate_limiter
.execute_with_retries(task, |res| {
let back_off_required = matches!(res, Err(RateLimiterError::RateLimited));
back_off_required
})
.await
.and_then(|result: Result<i32, RateLimiterError>| result);
assert_eq!(result, Err(RateLimiterError::RateLimited));
assert_eq!(call_count.load(Ordering::SeqCst), 2);
}
}

#[cfg(test)]
mod config_tests {
use super::*;

#[test]
fn parse_rate_limiting_strategy() {
let config_str = "1.5,10,30,3";
let strategy: RateLimitingStrategy = config_str.parse().unwrap();
assert_eq!(strategy.inner.back_off_growth_factor, 1.5);
assert_eq!(strategy.inner.min_back_off, Duration::from_secs(10));
assert_eq!(strategy.inner.max_back_off, Duration::from_secs(30));
assert_eq!(strategy.max_retries, 3);
}

#[test]
fn parse_rate_limiting_strategy_with_default_retries() {
let config_str = "1.5,10,30";
let strategy: RateLimitingStrategy = config_str.parse().unwrap();
assert_eq!(strategy.max_retries, DEFAULT_MAX_RETIRES);
}

#[test]
fn parse_invalid_rate_limiting_strategy() {
let config_str = "invalid";
assert!(config_str.parse::<RateLimitingStrategy>().is_err());
}

#[test]
fn parse_too_many_args_rate_limiting_strategy() {
let config_str = "1.5,10,30,3,10";
assert!(config_str.parse::<RateLimitingStrategy>().is_err());
}
}

0 comments on commit cf6e0dc

Please sign in to comment.