diff --git a/Cargo.lock b/Cargo.lock
index 4573053379fd..4befd5e13083 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -308,7 +308,7 @@ dependencies = [
  "sync_wrapper 1.0.1",
  "tokio",
  "tokio-tungstenite",
- "tower",
+ "tower 0.4.13",
  "tower-layer",
  "tower-service",
  "tracing",
@@ -352,7 +352,7 @@ dependencies = [
  "mime",
  "pin-project-lite",
  "serde",
- "tower",
+ "tower 0.4.13",
  "tower-layer",
  "tower-service",
  "tracing",
@@ -376,7 +376,7 @@ dependencies = [
  "once_cell",
  "pin-project",
  "tokio",
- "tower",
+ "tower 0.4.13",
  "tower-http",
 ]
 
@@ -1900,6 +1900,7 @@ dependencies = [
  "tabby-common",
  "tabby-inference",
  "tokio",
+ "tower 0.5.1",
 ]
 
 [[package]]
@@ -2061,7 +2062,7 @@ dependencies = [
  "pin-project-lite",
  "socket2",
  "tokio",
- "tower",
+ "tower 0.4.13",
  "tower-service",
  "tracing",
 ]
@@ -3068,7 +3069,7 @@ dependencies = [
  "serde_urlencoded",
  "snafu",
  "tokio",
- "tower",
+ "tower 0.4.13",
  "tower-http",
  "tracing",
  "url",
@@ -5772,6 +5773,23 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "tower"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
+dependencies = [
+ "futures-core",
+ "futures-util",
+ "pin-project-lite",
+ "sync_wrapper 0.1.2",
+ "tokio",
+ "tokio-util",
+ "tower-layer",
+ "tower-service",
+ "tracing",
+]
+
 [[package]]
 name = "tower-http"
 version = "0.5.2"
@@ -5787,7 +5805,7 @@ dependencies = [
  "iri-string",
  "pin-project-lite",
  "tokio",
- "tower",
+ "tower 0.4.13",
  "tower-layer",
  "tower-service",
  "tracing",
@@ -5795,15 +5813,15 @@ dependencies = [
 
 [[package]]
 name = "tower-layer"
-version = "0.3.2"
+version = "0.3.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
+checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
 
 [[package]]
 name = "tower-service"
-version = "0.3.2"
+version = "0.3.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
+checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
 
 [[package]]
 name = "tracing"
diff --git a/Cargo.toml b/Cargo.toml
index 7441bb8de7db..17fcc3731f5e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -63,6 +63,7 @@ url = "2.5.0"
 temp_testdir = "0.2"
 git2 = "0.18.3"
 tower-http = "0.5"
+tower = "0.5"
 mime_guess = "2.0.4"
 assert_matches = "1.5"
 insta = "1.34.0"
diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml
index 705fdfb152df..f6ad5644fee9 100644
--- a/crates/http-api-bindings/Cargo.toml
+++ b/crates/http-api-bindings/Cargo.toml
@@ -18,6 +18,8 @@ tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
 ollama-api-bindings = { path = "../ollama-api-bindings" }
 async-openai.workspace = true
+tower = { workspace = true, features = ["limit", "util", "buffer"] }
+tokio = { workspace = true }
 
 [dev-dependencies]
 tokio = { workspace = true, features = ["rt", "macros"] }
diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs
index fc42d123ec05..d5b8920af72f 100644
--- a/crates/http-api-bindings/src/embedding/mod.rs
+++ b/crates/http-api-bindings/src/embedding/mod.rs
@@ -1,18 +1,22 @@
 mod llama;
 mod openai;
+mod rate_limit;
 mod voyage;
 
 use core::panic;
 use std::sync::Arc;
 
 use llama::LlamaCppEngine;
+use rate_limit::RateLimitedEmbedding;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::Embedding;
 
 use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
-    match config.kind.as_str() {
+    let rpm = config.rate_limit.request_per_minute;
+
+    let embedding: Arc<dyn Embedding> = match config.kind.as_str() {
         "llama.cpp/embedding" => {
             let engine = LlamaCppEngine::create(
                 config
@@ -53,5 +57,9 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
             "Unsupported kind for http embedding model: {}",
             unsupported_kind
         ),
-    }
+    };
+
+    Arc::new(
+        RateLimitedEmbedding::new(embedding, rpm).expect("Failed to create rate limited embedding"),
+    )
 }
diff --git a/crates/http-api-bindings/src/embedding/rate_limit.rs b/crates/http-api-bindings/src/embedding/rate_limit.rs
new file mode 100644
index 000000000000..0f4df7c4ddb2
--- /dev/null
+++ b/crates/http-api-bindings/src/embedding/rate_limit.rs
@@ -0,0 +1,63 @@
+use std::{
+    sync::Arc,
+    task::{Context, Poll},
+    time,
+};
+
+use async_trait::async_trait;
+use futures::future::BoxFuture;
+use tabby_inference::Embedding;
+use tokio::sync::Mutex;
+use tower::{Service, ServiceBuilder, ServiceExt};
+
+struct EmbeddingService {
+    embedding: Arc<dyn Embedding>,
+}
+
+impl Service<String> for EmbeddingService {
+    type Response = Vec<f32>;
+    type Error = anyhow::Error;
+    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
+
+    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, prompt: String) -> Self::Future {
+        let embedding = self.embedding.clone();
+        Box::pin(async move { embedding.embed(&prompt).await })
+    }
+}
+
+pub struct RateLimitedEmbedding {
+    embedding: Arc<Mutex<tower::util::BoxService<String, Vec<f32>, anyhow::Error>>>,
+}
+
+impl RateLimitedEmbedding {
+    pub fn new(embedding: Arc<dyn Embedding>, rpm: u64) -> anyhow::Result<Self> {
+        if rpm == 0 {
+            anyhow::bail!(
+                "Cannot create rate limited embedding client with 0 requests per minute"
+            );
+        }
+
+        let service = ServiceBuilder::new()
+            .rate_limit(rpm, time::Duration::from_secs(60))
+            .service(EmbeddingService { embedding })
+            .boxed();
+
+        Ok(Self {
+            embedding: Arc::new(Mutex::new(service)),
+        })
+    }
+}
+
+#[async_trait]
+impl Embedding for RateLimitedEmbedding {
+    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
+        let mut service = self.embedding.lock().await;
+        let prompt_owned = prompt.to_string();
+        let response = service.ready().await?.call(prompt_owned).await?;
+        Ok(response)
+    }
+}
diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs
index 7f0df26a1349..4d32b4221904 100644
--- a/crates/tabby-common/src/config.rs
+++ b/crates/tabby-common/src/config.rs
@@ -289,6 +289,10 @@ pub struct HttpModelConfig {
     #[builder(default)]
     pub api_key: Option<String>,
 
+    #[builder(default)]
+    #[serde(default)]
+    pub rate_limit: RateLimit,
+
     /// Used by OpenAI style API for model name.
     #[builder(default)]
     pub model_name: Option<String>,
@@ -309,6 +313,20 @@ pub struct HttpModelConfig {
     pub additional_stop_words: Option<Vec<String>>,
 }
 
+#[derive(Serialize, Deserialize, Builder, Debug, Clone)]
+pub struct RateLimit {
+    // The maximum number of requests that can be made in one minute.
+    pub request_per_minute: u64,
+}
+
+impl Default for RateLimit {
+    fn default() -> Self {
+        Self {
+            request_per_minute: 600,
+        }
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 pub struct LocalModelConfig {
     pub model_id: String,
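For orientation, here is a minimal, self-contained sketch of the tower rate-limiting pattern that the new rate_limit.rs relies on: ServiceBuilder::rate_limit wraps an inner Service, and every request must pass through ready() before call() so the limiter can hold it until the current window has budget. The service_fn stand-in, the 2-requests-per-second budget, and the tokio/anyhow setup are illustrative assumptions rather than code from this patch; the patch itself wraps an Arc<dyn Embedding> with a request_per_minute budget over a 60-second window.

use std::time::Duration;

use tower::{service_fn, Service, ServiceBuilder, ServiceExt};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Stand-in for the wrapped embedding backend: maps a prompt to its length.
    let inner = service_fn(|prompt: String| async move {
        Ok::<usize, anyhow::Error>(prompt.len())
    });

    // Allow at most 2 calls per 1-second window; the patch instead allows
    // rate_limit.request_per_minute calls per 60-second window.
    let mut svc = ServiceBuilder::new()
        .rate_limit(2, Duration::from_secs(1))
        .service(inner);

    for prompt in ["foo", "bar", "baz"] {
        // ready() resolves only once the limiter has budget for another call.
        let len = svc.ready().await?.call(prompt.to_string()).await?;
        println!("{prompt} -> {len}");
    }
    Ok(())
}

In the patch, the boxed rate-limited service additionally sits behind a tokio Mutex so that Embedding::embed, which only receives &self, can drive the same ready()/call() sequence from concurrent tasks.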