From cf815009e9319c93d87cbec7ec277024f8f77d08 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 28 Nov 2024 10:11:37 +0800 Subject: [PATCH] feat(rate_limit): implement user rate limiting in tabby-webserver - Added `rate_limit` module to `ee/tabby-webserver/src/lib.rs`. - Updated `crates/http-api-bindings/Cargo.toml` and `Cargo.toml` to include `ratelimit` dependency. - Added `rate_limit.rs` to `ee/tabby-webserver/src/` with implementation for user rate limiting. - Configured rate limiters to allow 200 requests per minute per user. --- Cargo.lock | 2 ++ Cargo.toml | 1 + crates/http-api-bindings/Cargo.toml | 2 +- ee/tabby-webserver/Cargo.toml | 2 ++ ee/tabby-webserver/src/lib.rs | 1 + ee/tabby-webserver/src/rate_limit.rs | 39 +++++++++++++++++++++++++++ ee/tabby-webserver/src/service/mod.rs | 15 +++++++++++ 7 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 ee/tabby-webserver/src/rate_limit.rs diff --git a/Cargo.lock b/Cargo.lock index 011eb5f01e0b..424d65f6c69f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5282,6 +5282,7 @@ dependencies = [ "axum", "axum-extra", "bincode", + "cached", "chrono", "cron", "fs_extra", @@ -5300,6 +5301,7 @@ dependencies = [ "octocrab", "pin-project", "querystring", + "ratelimit", "reqwest", "rust-embed", "serde", diff --git a/Cargo.toml b/Cargo.toml index 7441bb8de7db..605e5269bbab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ logkit = "0.3" async-openai = "0.20" tracing-test = "0.2" clap = "4.3.0" +ratelimit = "0.10" [workspace.dependencies.uuid] version = "1.3.3" diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 9ac1541ba882..7037c5958b46 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -18,7 +18,7 @@ tabby-common = { path = "../tabby-common" } tabby-inference = { path = "../tabby-inference" } ollama-api-bindings = { path = "../ollama-api-bindings" } async-openai.workspace = true -ratelimit = "0.10" +ratelimit.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index cde4f85797dd..fb4ea15f1c72 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -54,6 +54,8 @@ cron = "0.12.1" async-stream.workspace = true logkit.workspace = true async-openai.workspace = true +ratelimit.workspace = true +cached = { workspace = true, features = ["async"] } [dev-dependencies] assert_matches.workspace = true diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 2e3352fb4165..e5fc5a9eb440 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -7,6 +7,7 @@ mod path; mod routes; mod service; mod webserver; +mod rate_limit; #[cfg(test)] pub use service::*; diff --git a/ee/tabby-webserver/src/rate_limit.rs b/ee/tabby-webserver/src/rate_limit.rs new file mode 100644 index 000000000000..88a65390368f --- /dev/null +++ b/ee/tabby-webserver/src/rate_limit.rs @@ -0,0 +1,39 @@ +use std::time::Duration; + +use cached::{Cached, TimedCache}; +use tokio::sync::Mutex; + +pub struct UserRateLimiter { + /// Mapping from user ID to rate limiter. + rate_limiters: Mutex>, +} + +static USER_REQUEST_LIMIT_PER_MINUTE: u64 = 200; + +impl Default for UserRateLimiter { + fn default() -> Self { + Self { + // User rate limiter is hardcoded to 200 requests per minute, thus the timespan is 60 seconds. + rate_limiters: Mutex::new(TimedCache::with_lifespan(60)), + } + } +} + +impl UserRateLimiter { + pub async fn is_allowed(&self, user_id: &str) -> bool { + let mut rate_limiters = self.rate_limiters.lock().await; + let rate_limiter = rate_limiters.cache_get_or_set_with(user_id.to_string(), || { + // Create a new rate limiter for this user. + ratelimit::Ratelimiter::builder(USER_REQUEST_LIMIT_PER_MINUTE, Duration::from_secs(60)) + .build() + .expect("Failed to create rate limiter") + }); + if let Err(_sleep) = rate_limiter.try_wait() { + // If the rate limiter is full, we return false. + false + } else { + // If the rate limiter is not full, we return true. + true + } + } +} diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 0cabb751ee20..4ccb2a23b4c3 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -58,6 +58,8 @@ use tabby_schema::{ AsID, AsRowid, CoreError, Result, ServiceLocator, }; +use crate::rate_limit::UserRateLimiter; + use self::{ analytic::new_analytic_service, email::new_email_service, license::new_license_service, }; @@ -83,6 +85,8 @@ struct ServerContext { code: Arc, setting: Arc, + + user_rate_limiter: UserRateLimiter, } impl ServerContext { @@ -153,6 +157,7 @@ impl ServerContext { user_group, access_policy, db_conn, + user_rate_limiter: UserRateLimiter::default(), } } @@ -213,6 +218,7 @@ impl WorkerService for ServerContext { let (auth, user) = self .authorize_request(request.uri(), request.headers()) .await; + let unauthorized = axum::response::Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Body::empty()) @@ -223,6 +229,15 @@ impl WorkerService for ServerContext { } if let Some(user) = user { + // Apply rate limiting when `user` is not none. + if !self.user_rate_limiter.is_allowed(&user).await { + return axum::response::Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .body(Body::empty()) + .unwrap() + .into_response(); + } + request.headers_mut().append( HeaderName::from_static(USER_HEADER_FIELD_NAME), HeaderValue::from_str(&user).expect("User must be valid header"),