From 268345ba83076bfd3bead4ecb773c6b1ffefa956 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Fri, 22 Nov 2024 18:12:40 +0800 Subject: [PATCH 01/12] feat(models-http-api): add rate limit in embedding api --- Cargo.lock | 38 +++++++--- Cargo.toml | 1 + crates/http-api-bindings/Cargo.toml | 2 + .../http-api-bindings/src/embedding/llama.rs | 28 +++---- crates/http-api-bindings/src/embedding/mod.rs | 7 ++ crates/http-api-bindings/src/lib.rs | 75 +++++++++++++++++++ crates/tabby-common/src/config.rs | 11 +++ 7 files changed, 138 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4573053379fd..4befd5e13083 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -308,7 +308,7 @@ dependencies = [ "sync_wrapper 1.0.1", "tokio", "tokio-tungstenite", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -352,7 +352,7 @@ dependencies = [ "mime", "pin-project-lite", "serde", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -376,7 +376,7 @@ dependencies = [ "once_cell", "pin-project", "tokio", - "tower", + "tower 0.4.13", "tower-http", ] @@ -1900,6 +1900,7 @@ dependencies = [ "tabby-common", "tabby-inference", "tokio", + "tower 0.5.1", ] [[package]] @@ -2061,7 +2062,7 @@ dependencies = [ "pin-project-lite", "socket2", "tokio", - "tower", + "tower 0.4.13", "tower-service", "tracing", ] @@ -3068,7 +3069,7 @@ dependencies = [ "serde_urlencoded", "snafu", "tokio", - "tower", + "tower 0.4.13", "tower-http", "tracing", "url", @@ -5772,6 +5773,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 0.1.2", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-http" version = "0.5.2" @@ -5787,7 +5805,7 @@ dependencies = [ "iri-string", "pin-project-lite", "tokio", - "tower", + "tower 0.4.13", "tower-layer", "tower-service", "tracing", @@ -5795,15 +5813,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" diff --git a/Cargo.toml b/Cargo.toml index 7441bb8de7db..17fcc3731f5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ url = "2.5.0" temp_testdir = "0.2" git2 = "0.18.3" tower-http = "0.5" +tower = "0.5" mime_guess = "2.0.4" assert_matches = "1.5" insta = "1.34.0" diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 705fdfb152df..f6ad5644fee9 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -18,6 +18,8 @@ tabby-common = { path = "../tabby-common" } tabby-inference = { path = "../tabby-inference" } ollama-api-bindings = { path = "../ollama-api-bindings" } async-openai.workspace = true +tower = { workspace = true , features = ["limit", "util", "buffer"] } +tokio = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["rt", "macros"] } diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs index 5dd4c1c2b76d..cf934243b037 100644 --- a/crates/http-api-bindings/src/embedding/llama.rs +++ b/crates/http-api-bindings/src/embedding/llama.rs @@ -1,23 +1,28 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use std::time; use tabby_inference::Embedding; -use crate::create_reqwest_client; +use crate::RateLimitedClient; pub struct LlamaCppEngine { - client: reqwest::Client, + client: RateLimitedClient, api_endpoint: String, - api_key: Option, } impl LlamaCppEngine { - pub fn create(api_endpoint: &str, api_key: Option) -> Self { - let client = create_reqwest_client(api_endpoint); + pub fn create(api_endpoint: &str, api_key: Option, num_request: u64, per: u64) -> Self { + let client = RateLimitedClient::new( + api_endpoint, + api_key, + num_request, + time::Duration::from_secs(per), + ) + .unwrap(); Self { client, api_endpoint: format!("{}/embedding", api_endpoint), - api_key, } } } @@ -39,12 +44,7 @@ impl Embedding for LlamaCppEngine { content: prompt.to_owned(), }; - let mut request = self.client.post(&self.api_endpoint).json(&request); - if let Some(api_key) = &self.api_key { - request = request.bearer_auth(api_key); - } - - let response = request.send().await?; + let response = self.client.post(&self.api_endpoint, request).await?; if response.status().is_server_error() { let error = response.text().await?; return Err(anyhow::anyhow!( @@ -68,8 +68,8 @@ mod tests { /// ./server -m ./models/nomic.gguf --port 8000 --embedding #[tokio::test] #[ignore] - async fn test_embedding() { - let engine = LlamaCppEngine::create("http://localhost:8000", None); + async fn test_embedding_no_limit() { + let engine = LlamaCppEngine::create("http://localhost:8000", None, 0, 0); let embedding = engine.embed("hello").await.unwrap(); assert_eq!(embedding.len(), 768); } diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs index fc42d123ec05..cbf3cca17f03 100644 --- a/crates/http-api-bindings/src/embedding/mod.rs +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -12,6 +12,11 @@ use tabby_inference::Embedding; use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine}; pub async fn create(config: &HttpModelConfig) -> Arc { + let (num_request, per) = match &config.request_limit { + Some(limit) => (limit.num_request, limit.per), + _ => (0, 0), + }; + match config.kind.as_str() { "llama.cpp/embedding" => { let engine = LlamaCppEngine::create( @@ -20,6 +25,8 @@ pub async fn create(config: &HttpModelConfig) -> Arc { .as_deref() .expect("api_endpoint is required"), config.api_key.clone(), + num_request, + per, ); Arc::new(engine) } diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index 41e7811421be..1e4f37ba166f 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -2,6 +2,12 @@ mod chat; mod completion; mod embedding; +use reqwest::{Response, Url}; +use serde::Serialize; +use std::{env, sync::Arc, time::Duration}; +use tokio::sync::Mutex; +use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt}; + pub use chat::create as create_chat; pub use completion::{build_completion_prompt, create}; pub use embedding::create as create_embedding; @@ -19,3 +25,72 @@ fn create_reqwest_client(api_endpoint: &str) -> reqwest::Client { builder.build().unwrap() } + +struct RateLimitedClient { + client: reqwest::Client, + rate_limit: Arc>>, + + api_key: Option, + + rate_limit_enabled: bool, +} + +impl RateLimitedClient { + pub fn new( + api_endpoint: &str, + api_key: Option, + num_request: u64, + per: Duration, + ) -> anyhow::Result { + if (num_request == 0) != (per.as_secs() == 0) { + anyhow::bail!("Both num_request and per must be zero or both must be non-zero"); + } + + let rate_limit_enabled = num_request > 0; + + let builder = reqwest::Client::builder(); + let is_localhost = api_endpoint.starts_with("http://localhost") + || api_endpoint.starts_with("http://127.0.0.1"); + let builder = if is_localhost { + builder.no_proxy() + } else { + builder + }; + + const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); + let client = builder.user_agent(USER_AGENT).tcp_nodelay(false).build()?; + + Ok(Self { + client: client.clone(), + rate_limit: Arc::new(Mutex::new( + ServiceBuilder::new() + .rate_limit(num_request, per) + .service(client), + )), + api_key, + rate_limit_enabled, + }) + } + + async fn post(&self, url: &str, body: T) -> anyhow::Result { + let url = Url::parse(url)?; + let mut builder = self.client.post(url.clone()).json(&body); + if let Some(api_key) = &self.api_key { + builder = builder.bearer_auth(api_key); + } + let request = builder.build()?; + + if self.rate_limit_enabled { + let future = self.rate_limit.lock().await.ready().await?.call(request); + future + .await + .and_then(|response| response.error_for_status()) + .map_err(|err| anyhow::anyhow!("Error from server: {}", err)) + } else { + let response = self.client.execute(request).await?; + response + .error_for_status() + .map_err(|err| anyhow::anyhow!("Error from server: {}", err)) + } + } +} diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs index 7f0df26a1349..237fd846a784 100644 --- a/crates/tabby-common/src/config.rs +++ b/crates/tabby-common/src/config.rs @@ -289,6 +289,9 @@ pub struct HttpModelConfig { #[builder(default)] pub api_key: Option, + #[builder(default)] + pub request_limit: Option, + /// Used by OpenAI style API for model name. #[builder(default)] pub model_name: Option, @@ -309,6 +312,14 @@ pub struct HttpModelConfig { pub additional_stop_words: Option>, } +#[derive(Serialize, Deserialize, Builder, Debug, Clone)] +pub struct RequestLimit { + // The limited number of requests can be made in following `per` time period. + pub num_request: u64, + // The time period in seconds to limit the number of requests. + pub per: u64, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct LocalModelConfig { pub model_id: String, From 796e9a270e2bce46003e3616ff1479854aba8e79 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:46:38 +0000 Subject: [PATCH 02/12] [autofix.ci] apply automated fixes --- crates/http-api-bindings/src/embedding/llama.rs | 3 ++- crates/http-api-bindings/src/lib.rs | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs index cf934243b037..807050ef518b 100644 --- a/crates/http-api-bindings/src/embedding/llama.rs +++ b/crates/http-api-bindings/src/embedding/llama.rs @@ -1,6 +1,7 @@ +use std::time; + use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use std::time; use tabby_inference::Embedding; use crate::RateLimitedClient; diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index 1e4f37ba166f..4b0871d5232a 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -2,15 +2,15 @@ mod chat; mod completion; mod embedding; -use reqwest::{Response, Url}; -use serde::Serialize; use std::{env, sync::Arc, time::Duration}; -use tokio::sync::Mutex; -use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt}; pub use chat::create as create_chat; pub use completion::{build_completion_prompt, create}; pub use embedding::create as create_embedding; +use reqwest::{Response, Url}; +use serde::Serialize; +use tokio::sync::Mutex; +use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt}; fn create_reqwest_client(api_endpoint: &str) -> reqwest::Client { let builder = reqwest::Client::builder(); From 83ab30c5053038bba39fdce04c5d2be871bc17db Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Fri, 22 Nov 2024 20:17:02 +0800 Subject: [PATCH 03/12] chore: use HttpClient trait for rate limit and client --- .../http-api-bindings/src/embedding/llama.rs | 37 ++++++++----- crates/http-api-bindings/src/lib.rs | 54 +++++++++++-------- 2 files changed, 54 insertions(+), 37 deletions(-) diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs index 807050ef518b..3fa2aff3d47d 100644 --- a/crates/http-api-bindings/src/embedding/llama.rs +++ b/crates/http-api-bindings/src/embedding/llama.rs @@ -2,28 +2,37 @@ use std::time; use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::sync::Arc; use tabby_inference::Embedding; -use crate::RateLimitedClient; +use crate::{create_reqwest_client, HttpClient, RateLimitedClient}; pub struct LlamaCppEngine { - client: RateLimitedClient, + client: Arc, api_endpoint: String, } impl LlamaCppEngine { pub fn create(api_endpoint: &str, api_key: Option, num_request: u64, per: u64) -> Self { - let client = RateLimitedClient::new( - api_endpoint, - api_key, - num_request, - time::Duration::from_secs(per), - ) - .unwrap(); - - Self { - client, - api_endpoint: format!("{}/embedding", api_endpoint), + if num_request > 0 { + let client = RateLimitedClient::new( + api_endpoint, + api_key, + num_request, + time::Duration::from_secs(per), + ) + .unwrap(); + Self { + client: Arc::new(client), + api_endpoint: format!("{}/embedding", api_endpoint), + } + } else { + let client = create_reqwest_client(api_endpoint); + Self { + client: Arc::new(client), + api_endpoint: format!("{}/embedding", api_endpoint), + } } } } @@ -45,7 +54,7 @@ impl Embedding for LlamaCppEngine { content: prompt.to_owned(), }; - let response = self.client.post(&self.api_endpoint, request).await?; + let response = self.client.post(&self.api_endpoint, json!(request)).await?; if response.status().is_server_error() { let error = response.text().await?; return Err(anyhow::anyhow!( diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index 4b0871d5232a..7e1af54c5b64 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -2,16 +2,20 @@ mod chat; mod completion; mod embedding; -use std::{env, sync::Arc, time::Duration}; - pub use chat::create as create_chat; pub use completion::{build_completion_prompt, create}; pub use embedding::create as create_embedding; use reqwest::{Response, Url}; -use serde::Serialize; +use serde_json::Value; +use std::{env, sync::Arc, time::Duration}; use tokio::sync::Mutex; use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt}; +#[async_trait::async_trait] +pub trait HttpClient: Send + Sync { + async fn post(&self, url: &str, body: Value) -> anyhow::Result; +} + fn create_reqwest_client(api_endpoint: &str) -> reqwest::Client { let builder = reqwest::Client::builder(); @@ -31,8 +35,6 @@ struct RateLimitedClient { rate_limit: Arc>>, api_key: Option, - - rate_limit_enabled: bool, } impl RateLimitedClient { @@ -42,12 +44,10 @@ impl RateLimitedClient { num_request: u64, per: Duration, ) -> anyhow::Result { - if (num_request == 0) != (per.as_secs() == 0) { - anyhow::bail!("Both num_request and per must be zero or both must be non-zero"); + if num_request == 0 || per.as_secs() == 0 { + anyhow::bail!("Both num_request and per must be non-zero"); } - let rate_limit_enabled = num_request > 0; - let builder = reqwest::Client::builder(); let is_localhost = api_endpoint.starts_with("http://localhost") || api_endpoint.starts_with("http://127.0.0.1"); @@ -68,11 +68,13 @@ impl RateLimitedClient { .service(client), )), api_key, - rate_limit_enabled, }) } +} - async fn post(&self, url: &str, body: T) -> anyhow::Result { +#[async_trait::async_trait] +impl HttpClient for RateLimitedClient { + async fn post(&self, url: &str, body: Value) -> anyhow::Result { let url = Url::parse(url)?; let mut builder = self.client.post(url.clone()).json(&body); if let Some(api_key) = &self.api_key { @@ -80,17 +82,23 @@ impl RateLimitedClient { } let request = builder.build()?; - if self.rate_limit_enabled { - let future = self.rate_limit.lock().await.ready().await?.call(request); - future - .await - .and_then(|response| response.error_for_status()) - .map_err(|err| anyhow::anyhow!("Error from server: {}", err)) - } else { - let response = self.client.execute(request).await?; - response - .error_for_status() - .map_err(|err| anyhow::anyhow!("Error from server: {}", err)) - } + let future = self.rate_limit.lock().await.ready().await?.call(request); + let response = future.await?; + + response + .error_for_status() + .map_err(|err| anyhow::anyhow!("Error from server: {}", err)) + } +} + +#[async_trait::async_trait] +impl HttpClient for reqwest::Client { + async fn post(&self, url: &str, body: Value) -> anyhow::Result { + let url = Url::parse(url)?; + let response = self.post(url).json(&body).send().await?; + + response + .error_for_status() + .map_err(|err| anyhow::anyhow!("Error from server: {}", err)) } } From 3c3689c0344a423943e09390e38562fad4407e91 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:25:41 +0000 Subject: [PATCH 04/12] [autofix.ci] apply automated fixes --- crates/http-api-bindings/src/embedding/llama.rs | 3 +-- crates/http-api-bindings/src/lib.rs | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs index 3fa2aff3d47d..322564a52d41 100644 --- a/crates/http-api-bindings/src/embedding/llama.rs +++ b/crates/http-api-bindings/src/embedding/llama.rs @@ -1,9 +1,8 @@ -use std::time; +use std::{sync::Arc, time}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::json; -use std::sync::Arc; use tabby_inference::Embedding; use crate::{create_reqwest_client, HttpClient, RateLimitedClient}; diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index 7e1af54c5b64..30f936e6f41f 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -2,12 +2,13 @@ mod chat; mod completion; mod embedding; +use std::{env, sync::Arc, time::Duration}; + pub use chat::create as create_chat; pub use completion::{build_completion_prompt, create}; pub use embedding::create as create_embedding; use reqwest::{Response, Url}; use serde_json::Value; -use std::{env, sync::Arc, time::Duration}; use tokio::sync::Mutex; use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt}; From a37eb8d8c6395c898cb8ed024478852cb08765c6 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Fri, 22 Nov 2024 21:33:44 +0800 Subject: [PATCH 05/12] feat: rate limit on embedding --- Cargo.lock | 1 + crates/http-api-bindings/Cargo.toml | 1 + .../http-api-bindings/src/embedding/llama.rs | 45 ++++------ crates/http-api-bindings/src/embedding/mod.rs | 82 +++++++++++++++++-- .../http-api-bindings/src/embedding/openai.rs | 2 +- .../http-api-bindings/src/embedding/voyage.rs | 2 +- crates/http-api-bindings/src/lib.rs | 79 +----------------- 7 files changed, 99 insertions(+), 113 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4befd5e13083..2e2849e3c1cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1901,6 +1901,7 @@ dependencies = [ "tabby-inference", "tokio", "tower 0.5.1", + "tracing", ] [[package]] diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index f6ad5644fee9..0f86424886a6 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -20,6 +20,7 @@ ollama-api-bindings = { path = "../ollama-api-bindings" } async-openai.workspace = true tower = { workspace = true , features = ["limit", "util", "buffer"] } tokio = { workspace = true } +tracing = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["rt", "macros"] } diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs index 322564a52d41..b3ea0f5aac70 100644 --- a/crates/http-api-bindings/src/embedding/llama.rs +++ b/crates/http-api-bindings/src/embedding/llama.rs @@ -1,37 +1,23 @@ -use std::{sync::Arc, time}; - use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use serde_json::json; use tabby_inference::Embedding; -use crate::{create_reqwest_client, HttpClient, RateLimitedClient}; +use crate::create_reqwest_client; pub struct LlamaCppEngine { - client: Arc, + client: reqwest::Client, api_endpoint: String, + api_key: Option, } impl LlamaCppEngine { - pub fn create(api_endpoint: &str, api_key: Option, num_request: u64, per: u64) -> Self { - if num_request > 0 { - let client = RateLimitedClient::new( - api_endpoint, - api_key, - num_request, - time::Duration::from_secs(per), - ) - .unwrap(); - Self { - client: Arc::new(client), - api_endpoint: format!("{}/embedding", api_endpoint), - } - } else { - let client = create_reqwest_client(api_endpoint); - Self { - client: Arc::new(client), - api_endpoint: format!("{}/embedding", api_endpoint), - } + pub fn create(api_endpoint: &str, api_key: Option) -> impl Embedding { + let client = create_reqwest_client(api_endpoint); + + Self { + client: client, + api_endpoint: format!("{}/embedding", api_endpoint), + api_key, } } } @@ -53,7 +39,12 @@ impl Embedding for LlamaCppEngine { content: prompt.to_owned(), }; - let response = self.client.post(&self.api_endpoint, json!(request)).await?; + let mut request = self.client.post(&self.api_endpoint).json(&request); + if let Some(api_key) = &self.api_key { + request = request.bearer_auth(api_key); + } + + let response = request.send().await?; if response.status().is_server_error() { let error = response.text().await?; return Err(anyhow::anyhow!( @@ -77,8 +68,8 @@ mod tests { /// ./server -m ./models/nomic.gguf --port 8000 --embedding #[tokio::test] #[ignore] - async fn test_embedding_no_limit() { - let engine = LlamaCppEngine::create("http://localhost:8000", None, 0, 0); + async fn test_embedding() { + let engine = LlamaCppEngine::create("http://localhost:8000", None); let embedding = engine.embed("hello").await.unwrap(); assert_eq!(embedding.len(), 768); } diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs index cbf3cca17f03..3d8af2dd2f5b 100644 --- a/crates/http-api-bindings/src/embedding/mod.rs +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -2,12 +2,17 @@ mod llama; mod openai; mod voyage; +use async_trait::async_trait; use core::panic; -use std::sync::Arc; - +use futures::future::BoxFuture; use llama::LlamaCppEngine; +use std::task::{Context, Poll}; +use std::{sync::Arc, time::Duration}; use tabby_common::config::HttpModelConfig; use tabby_inference::Embedding; +use tokio::sync::Mutex; +use tower::{Service, ServiceBuilder, ServiceExt}; +use tracing::debug; use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine}; @@ -17,7 +22,7 @@ pub async fn create(config: &HttpModelConfig) -> Arc { _ => (0, 0), }; - match config.kind.as_str() { + let embedding = match config.kind.as_str() { "llama.cpp/embedding" => { let engine = LlamaCppEngine::create( config @@ -25,11 +30,10 @@ pub async fn create(config: &HttpModelConfig) -> Arc { .as_deref() .expect("api_endpoint is required"), config.api_key.clone(), - num_request, - per, ); Arc::new(engine) } + "ollama/embedding" => ollama_api_bindings::create_embedding(config).await, "openai/embedding" => { let engine = OpenAIEmbeddingEngine::create( config @@ -41,7 +45,6 @@ pub async fn create(config: &HttpModelConfig) -> Arc { ); Arc::new(engine) } - "ollama/embedding" => ollama_api_bindings::create_embedding(config).await, "voyage/embedding" => { let engine = VoyageEmbeddingEngine::create( config.api_endpoint.as_deref(), @@ -60,5 +63,72 @@ pub async fn create(config: &HttpModelConfig) -> Arc { "Unsupported kind for http embedding model: {}", unsupported_kind ), + }; + + if num_request > 0 { + debug!( + "Creating rate limited embedding with {} requests per {}s", + num_request, per, + ); + Arc::new( + RateLimitedEmbedding::new(embedding, num_request, Duration::from_secs(per)) + .expect("Failed to create rate limited embedding"), + ) + } else { + embedding + } +} + +struct EmbeddingService { + embedding: Arc, +} + +impl Service for EmbeddingService { + type Response = Vec; + type Error = anyhow::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, prompt: String) -> Self::Future { + let embedding = self.embedding.clone(); + Box::pin(async move { embedding.embed(&prompt).await }) + } +} + +pub struct RateLimitedEmbedding { + embedding: Arc, anyhow::Error>>>, +} + +impl RateLimitedEmbedding { + pub fn new( + embedding: Arc, + num_request: u64, + per: Duration, + ) -> anyhow::Result { + if num_request == 0 || per.as_secs() == 0 { + anyhow::bail!("Both num_request and per must be non-zero"); + } + + let service = ServiceBuilder::new() + .rate_limit(num_request, per) + .service(EmbeddingService { embedding }) + .boxed(); + + Ok(Self { + embedding: Arc::new(Mutex::new(service)), + }) + } +} + +#[async_trait] +impl Embedding for RateLimitedEmbedding { + async fn embed(&self, prompt: &str) -> anyhow::Result> { + let mut service = self.embedding.lock().await; + let prompt_owned = prompt.to_string(); + let response = service.ready().await?.call(prompt_owned).await?; + Ok(response) } } diff --git a/crates/http-api-bindings/src/embedding/openai.rs b/crates/http-api-bindings/src/embedding/openai.rs index d524fd70ee07..40fae6368e86 100644 --- a/crates/http-api-bindings/src/embedding/openai.rs +++ b/crates/http-api-bindings/src/embedding/openai.rs @@ -12,7 +12,7 @@ pub struct OpenAIEmbeddingEngine { } impl OpenAIEmbeddingEngine { - pub fn create(api_endpoint: &str, model_name: &str, api_key: Option<&str>) -> Self { + pub fn create(api_endpoint: &str, model_name: &str, api_key: Option<&str>) -> impl Embedding { let config = OpenAIConfig::default() .with_api_base(api_endpoint) .with_api_key(api_key.unwrap_or_default()); diff --git a/crates/http-api-bindings/src/embedding/voyage.rs b/crates/http-api-bindings/src/embedding/voyage.rs index e403d790010e..b813d811fae8 100644 --- a/crates/http-api-bindings/src/embedding/voyage.rs +++ b/crates/http-api-bindings/src/embedding/voyage.rs @@ -14,7 +14,7 @@ pub struct VoyageEmbeddingEngine { } impl VoyageEmbeddingEngine { - pub fn create(api_endpoint: Option<&str>, model_name: &str, api_key: String) -> Self { + pub fn create(api_endpoint: Option<&str>, model_name: &str, api_key: String) -> impl Embedding { let api_endpoint = api_endpoint.unwrap_or(DEFAULT_VOYAGE_API_ENDPOINT); let client = Client::new(); Self { diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index 30f936e6f41f..2c68ebe7be06 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -2,15 +2,11 @@ mod chat; mod completion; mod embedding; -use std::{env, sync::Arc, time::Duration}; - pub use chat::create as create_chat; pub use completion::{build_completion_prompt, create}; pub use embedding::create as create_embedding; -use reqwest::{Response, Url}; +use reqwest::Response; use serde_json::Value; -use tokio::sync::Mutex; -use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt}; #[async_trait::async_trait] pub trait HttpClient: Send + Sync { @@ -30,76 +26,3 @@ fn create_reqwest_client(api_endpoint: &str) -> reqwest::Client { builder.build().unwrap() } - -struct RateLimitedClient { - client: reqwest::Client, - rate_limit: Arc>>, - - api_key: Option, -} - -impl RateLimitedClient { - pub fn new( - api_endpoint: &str, - api_key: Option, - num_request: u64, - per: Duration, - ) -> anyhow::Result { - if num_request == 0 || per.as_secs() == 0 { - anyhow::bail!("Both num_request and per must be non-zero"); - } - - let builder = reqwest::Client::builder(); - let is_localhost = api_endpoint.starts_with("http://localhost") - || api_endpoint.starts_with("http://127.0.0.1"); - let builder = if is_localhost { - builder.no_proxy() - } else { - builder - }; - - const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); - let client = builder.user_agent(USER_AGENT).tcp_nodelay(false).build()?; - - Ok(Self { - client: client.clone(), - rate_limit: Arc::new(Mutex::new( - ServiceBuilder::new() - .rate_limit(num_request, per) - .service(client), - )), - api_key, - }) - } -} - -#[async_trait::async_trait] -impl HttpClient for RateLimitedClient { - async fn post(&self, url: &str, body: Value) -> anyhow::Result { - let url = Url::parse(url)?; - let mut builder = self.client.post(url.clone()).json(&body); - if let Some(api_key) = &self.api_key { - builder = builder.bearer_auth(api_key); - } - let request = builder.build()?; - - let future = self.rate_limit.lock().await.ready().await?.call(request); - let response = future.await?; - - response - .error_for_status() - .map_err(|err| anyhow::anyhow!("Error from server: {}", err)) - } -} - -#[async_trait::async_trait] -impl HttpClient for reqwest::Client { - async fn post(&self, url: &str, body: Value) -> anyhow::Result { - let url = Url::parse(url)?; - let response = self.post(url).json(&body).send().await?; - - response - .error_for_status() - .map_err(|err| anyhow::anyhow!("Error from server: {}", err)) - } -} From d46923c89a5f28d91ec8a4fec6a90ce334ea43e2 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 13:41:31 +0000 Subject: [PATCH 06/12] [autofix.ci] apply automated fixes --- crates/http-api-bindings/src/embedding/llama.rs | 2 +- crates/http-api-bindings/src/embedding/mod.rs | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs index b3ea0f5aac70..041411020155 100644 --- a/crates/http-api-bindings/src/embedding/llama.rs +++ b/crates/http-api-bindings/src/embedding/llama.rs @@ -15,7 +15,7 @@ impl LlamaCppEngine { let client = create_reqwest_client(api_endpoint); Self { - client: client, + client, api_endpoint: format!("{}/embedding", api_endpoint), api_key, } diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs index 3d8af2dd2f5b..d60316d8ce52 100644 --- a/crates/http-api-bindings/src/embedding/mod.rs +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -2,12 +2,16 @@ mod llama; mod openai; mod voyage; -use async_trait::async_trait; use core::panic; +use std::{ + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use async_trait::async_trait; use futures::future::BoxFuture; use llama::LlamaCppEngine; -use std::task::{Context, Poll}; -use std::{sync::Arc, time::Duration}; use tabby_common::config::HttpModelConfig; use tabby_inference::Embedding; use tokio::sync::Mutex; From 01a7ac1cabc875c502fa241ab000b6fdd816ab11 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Sun, 24 Nov 2024 23:29:47 +0800 Subject: [PATCH 07/12] config(models-http-api): use rpm instead of request and per --- crates/http-api-bindings/src/embedding/mod.rs | 32 +++++++++---------- crates/tabby-common/src/config.rs | 6 ++-- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs index d60316d8ce52..e84a3f797427 100644 --- a/crates/http-api-bindings/src/embedding/mod.rs +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -2,11 +2,10 @@ mod llama; mod openai; mod voyage; -use core::panic; +use core::{panic, time}; use std::{ sync::Arc, task::{Context, Poll}, - time::Duration, }; use async_trait::async_trait; @@ -21,9 +20,10 @@ use tracing::debug; use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine}; pub async fn create(config: &HttpModelConfig) -> Arc { - let (num_request, per) = match &config.request_limit { - Some(limit) => (limit.num_request, limit.per), - _ => (0, 0), + let rpm = if let Some(limit) = &config.request_limit { + limit.request_per_minute + } else { + 0 }; let embedding = match config.kind.as_str() { @@ -69,13 +69,13 @@ pub async fn create(config: &HttpModelConfig) -> Arc { ), }; - if num_request > 0 { + if rpm > 0 { debug!( - "Creating rate limited embedding with {} requests per {}s", - num_request, per, + "Creating rate limited embedding with {} requests per minute", + rpm, ); Arc::new( - RateLimitedEmbedding::new(embedding, num_request, Duration::from_secs(per)) + RateLimitedEmbedding::new(embedding, rpm) .expect("Failed to create rate limited embedding"), ) } else { @@ -107,17 +107,15 @@ pub struct RateLimitedEmbedding { } impl RateLimitedEmbedding { - pub fn new( - embedding: Arc, - num_request: u64, - per: Duration, - ) -> anyhow::Result { - if num_request == 0 || per.as_secs() == 0 { - anyhow::bail!("Both num_request and per must be non-zero"); + pub fn new(embedding: Arc, rpm: u64) -> anyhow::Result { + if rpm == 0 { + anyhow::bail!( + "Can not create rate limited embedding client with 0 requests per minute" + ); } let service = ServiceBuilder::new() - .rate_limit(num_request, per) + .rate_limit(rpm, time::Duration::from_secs(60)) .service(EmbeddingService { embedding }) .boxed(); diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs index 237fd846a784..c1832e4f2ac9 100644 --- a/crates/tabby-common/src/config.rs +++ b/crates/tabby-common/src/config.rs @@ -314,10 +314,8 @@ pub struct HttpModelConfig { #[derive(Serialize, Deserialize, Builder, Debug, Clone)] pub struct RequestLimit { - // The limited number of requests can be made in following `per` time period. - pub num_request: u64, - // The time period in seconds to limit the number of requests. - pub per: u64, + // The limited number of requests can be made in one minute. + pub request_per_minute: u64, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] From afe48c3600f378d76179d0a4b30cb994249c8a0f Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Mon, 25 Nov 2024 10:02:45 +0800 Subject: [PATCH 08/12] chore: 600 rpm for embedding by default --- .../http-api-bindings/src/embedding/llama.rs | 2 +- crates/http-api-bindings/src/embedding/mod.rs | 90 +++---------------- .../http-api-bindings/src/embedding/openai.rs | 2 +- .../src/embedding/rate_limit.rs | 64 +++++++++++++ .../http-api-bindings/src/embedding/voyage.rs | 2 +- crates/http-api-bindings/src/lib.rs | 7 -- 6 files changed, 81 insertions(+), 86 deletions(-) create mode 100644 crates/http-api-bindings/src/embedding/rate_limit.rs diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs index 041411020155..5dd4c1c2b76d 100644 --- a/crates/http-api-bindings/src/embedding/llama.rs +++ b/crates/http-api-bindings/src/embedding/llama.rs @@ -11,7 +11,7 @@ pub struct LlamaCppEngine { } impl LlamaCppEngine { - pub fn create(api_endpoint: &str, api_key: Option) -> impl Embedding { + pub fn create(api_endpoint: &str, api_key: Option) -> Self { let client = create_reqwest_client(api_endpoint); Self { diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs index e84a3f797427..25717e259ad2 100644 --- a/crates/http-api-bindings/src/embedding/mod.rs +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -1,20 +1,15 @@ mod llama; mod openai; +mod rate_limit; mod voyage; -use core::{panic, time}; -use std::{ - sync::Arc, - task::{Context, Poll}, -}; +use core::panic; +use std::sync::Arc; -use async_trait::async_trait; -use futures::future::BoxFuture; use llama::LlamaCppEngine; +use rate_limit::RateLimitedEmbedding; use tabby_common::config::HttpModelConfig; use tabby_inference::Embedding; -use tokio::sync::Mutex; -use tower::{Service, ServiceBuilder, ServiceExt}; use tracing::debug; use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine}; @@ -23,10 +18,14 @@ pub async fn create(config: &HttpModelConfig) -> Arc { let rpm = if let Some(limit) = &config.request_limit { limit.request_per_minute } else { - 0 + debug!( + "No request limit specified for model {}, defaulting to 600 rpm", + config.kind + ); + 600 }; - let embedding = match config.kind.as_str() { + let embedding: Arc = match config.kind.as_str() { "llama.cpp/embedding" => { let engine = LlamaCppEngine::create( config @@ -37,7 +36,6 @@ pub async fn create(config: &HttpModelConfig) -> Arc { ); Arc::new(engine) } - "ollama/embedding" => ollama_api_bindings::create_embedding(config).await, "openai/embedding" => { let engine = OpenAIEmbeddingEngine::create( config @@ -49,6 +47,7 @@ pub async fn create(config: &HttpModelConfig) -> Arc { ); Arc::new(engine) } + "ollama/embedding" => ollama_api_bindings::create_embedding(config).await, "voyage/embedding" => { let engine = VoyageEmbeddingEngine::create( config.api_endpoint.as_deref(), @@ -69,68 +68,7 @@ pub async fn create(config: &HttpModelConfig) -> Arc { ), }; - if rpm > 0 { - debug!( - "Creating rate limited embedding with {} requests per minute", - rpm, - ); - Arc::new( - RateLimitedEmbedding::new(embedding, rpm) - .expect("Failed to create rate limited embedding"), - ) - } else { - embedding - } -} - -struct EmbeddingService { - embedding: Arc, -} - -impl Service for EmbeddingService { - type Response = Vec; - type Error = anyhow::Error; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, prompt: String) -> Self::Future { - let embedding = self.embedding.clone(); - Box::pin(async move { embedding.embed(&prompt).await }) - } -} - -pub struct RateLimitedEmbedding { - embedding: Arc, anyhow::Error>>>, -} - -impl RateLimitedEmbedding { - pub fn new(embedding: Arc, rpm: u64) -> anyhow::Result { - if rpm == 0 { - anyhow::bail!( - "Can not create rate limited embedding client with 0 requests per minute" - ); - } - - let service = ServiceBuilder::new() - .rate_limit(rpm, time::Duration::from_secs(60)) - .service(EmbeddingService { embedding }) - .boxed(); - - Ok(Self { - embedding: Arc::new(Mutex::new(service)), - }) - } -} - -#[async_trait] -impl Embedding for RateLimitedEmbedding { - async fn embed(&self, prompt: &str) -> anyhow::Result> { - let mut service = self.embedding.lock().await; - let prompt_owned = prompt.to_string(); - let response = service.ready().await?.call(prompt_owned).await?; - Ok(response) - } + Arc::new( + RateLimitedEmbedding::new(embedding, rpm).expect("Failed to create rate limited embedding"), + ) } diff --git a/crates/http-api-bindings/src/embedding/openai.rs b/crates/http-api-bindings/src/embedding/openai.rs index 40fae6368e86..d524fd70ee07 100644 --- a/crates/http-api-bindings/src/embedding/openai.rs +++ b/crates/http-api-bindings/src/embedding/openai.rs @@ -12,7 +12,7 @@ pub struct OpenAIEmbeddingEngine { } impl OpenAIEmbeddingEngine { - pub fn create(api_endpoint: &str, model_name: &str, api_key: Option<&str>) -> impl Embedding { + pub fn create(api_endpoint: &str, model_name: &str, api_key: Option<&str>) -> Self { let config = OpenAIConfig::default() .with_api_base(api_endpoint) .with_api_key(api_key.unwrap_or_default()); diff --git a/crates/http-api-bindings/src/embedding/rate_limit.rs b/crates/http-api-bindings/src/embedding/rate_limit.rs new file mode 100644 index 000000000000..ddfe0ef891ce --- /dev/null +++ b/crates/http-api-bindings/src/embedding/rate_limit.rs @@ -0,0 +1,64 @@ +use async_trait::async_trait; +use futures::future::BoxFuture; +use tokio::sync::Mutex; +use tower::{Service, ServiceBuilder, ServiceExt}; + +use core::time; +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use tabby_inference::Embedding; + +struct EmbeddingService { + embedding: Arc, +} + +impl Service for EmbeddingService { + type Response = Vec; + type Error = anyhow::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, prompt: String) -> Self::Future { + let embedding = self.embedding.clone(); + Box::pin(async move { embedding.embed(&prompt).await }) + } +} + +pub struct RateLimitedEmbedding { + embedding: Arc, anyhow::Error>>>, +} + +impl RateLimitedEmbedding { + pub fn new(embedding: Arc, rpm: u64) -> anyhow::Result { + if rpm == 0 { + anyhow::bail!( + "Can not create rate limited embedding client with 0 requests per minute" + ); + } + + let service = ServiceBuilder::new() + .rate_limit(rpm, time::Duration::from_secs(60)) + .service(EmbeddingService { embedding }) + .boxed(); + + Ok(Self { + embedding: Arc::new(Mutex::new(service)), + }) + } +} + +#[async_trait] +impl Embedding for RateLimitedEmbedding { + async fn embed(&self, prompt: &str) -> anyhow::Result> { + let mut service = self.embedding.lock().await; + let prompt_owned = prompt.to_string(); + let response = service.ready().await?.call(prompt_owned).await?; + Ok(response) + } +} diff --git a/crates/http-api-bindings/src/embedding/voyage.rs b/crates/http-api-bindings/src/embedding/voyage.rs index b813d811fae8..e403d790010e 100644 --- a/crates/http-api-bindings/src/embedding/voyage.rs +++ b/crates/http-api-bindings/src/embedding/voyage.rs @@ -14,7 +14,7 @@ pub struct VoyageEmbeddingEngine { } impl VoyageEmbeddingEngine { - pub fn create(api_endpoint: Option<&str>, model_name: &str, api_key: String) -> impl Embedding { + pub fn create(api_endpoint: Option<&str>, model_name: &str, api_key: String) -> Self { let api_endpoint = api_endpoint.unwrap_or(DEFAULT_VOYAGE_API_ENDPOINT); let client = Client::new(); Self { diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index 2c68ebe7be06..41e7811421be 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -5,13 +5,6 @@ mod embedding; pub use chat::create as create_chat; pub use completion::{build_completion_prompt, create}; pub use embedding::create as create_embedding; -use reqwest::Response; -use serde_json::Value; - -#[async_trait::async_trait] -pub trait HttpClient: Send + Sync { - async fn post(&self, url: &str, body: Value) -> anyhow::Result; -} fn create_reqwest_client(api_endpoint: &str) -> reqwest::Client { let builder = reqwest::Client::builder(); From dc11cfb82e8923ab74b61418e967125d95f475f7 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 02:08:17 +0000 Subject: [PATCH 09/12] [autofix.ci] apply automated fixes --- crates/http-api-bindings/src/embedding/rate_limit.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/http-api-bindings/src/embedding/rate_limit.rs b/crates/http-api-bindings/src/embedding/rate_limit.rs index ddfe0ef891ce..9f1c5d256416 100644 --- a/crates/http-api-bindings/src/embedding/rate_limit.rs +++ b/crates/http-api-bindings/src/embedding/rate_limit.rs @@ -1,15 +1,14 @@ -use async_trait::async_trait; -use futures::future::BoxFuture; -use tokio::sync::Mutex; -use tower::{Service, ServiceBuilder, ServiceExt}; - use core::time; use std::{ sync::Arc, task::{Context, Poll}, }; +use async_trait::async_trait; +use futures::future::BoxFuture; use tabby_inference::Embedding; +use tokio::sync::Mutex; +use tower::{Service, ServiceBuilder, ServiceExt}; struct EmbeddingService { embedding: Arc, From fdeedddd04524da1cb2df7881e5ed6824dbeeec7 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Mon, 25 Nov 2024 10:53:29 +0800 Subject: [PATCH 10/12] chore: use rate limit as config key --- crates/http-api-bindings/src/embedding/mod.rs | 16 +++++----------- .../src/embedding/rate_limit.rs | 2 +- crates/tabby-common/src/config.rs | 12 ++++++++++-- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs index 25717e259ad2..1d6e3cb3d333 100644 --- a/crates/http-api-bindings/src/embedding/mod.rs +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -8,22 +8,16 @@ use std::sync::Arc; use llama::LlamaCppEngine; use rate_limit::RateLimitedEmbedding; -use tabby_common::config::HttpModelConfig; +use tabby_common::config::{HttpModelConfig, RateLimit}; use tabby_inference::Embedding; -use tracing::debug; use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine}; pub async fn create(config: &HttpModelConfig) -> Arc { - let rpm = if let Some(limit) = &config.request_limit { - limit.request_per_minute - } else { - debug!( - "No request limit specified for model {}, defaulting to 600 rpm", - config.kind - ); - 600 - }; + let rpm = config.rate_limit.as_ref().map_or_else( + || RateLimit::default().request_per_minute, + |rl| rl.request_per_minute, + ); let embedding: Arc = match config.kind.as_str() { "llama.cpp/embedding" => { diff --git a/crates/http-api-bindings/src/embedding/rate_limit.rs b/crates/http-api-bindings/src/embedding/rate_limit.rs index 9f1c5d256416..0f4df7c4ddb2 100644 --- a/crates/http-api-bindings/src/embedding/rate_limit.rs +++ b/crates/http-api-bindings/src/embedding/rate_limit.rs @@ -1,7 +1,7 @@ -use core::time; use std::{ sync::Arc, task::{Context, Poll}, + time, }; use async_trait::async_trait; diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs index c1832e4f2ac9..56acd7590ae8 100644 --- a/crates/tabby-common/src/config.rs +++ b/crates/tabby-common/src/config.rs @@ -290,7 +290,7 @@ pub struct HttpModelConfig { pub api_key: Option, #[builder(default)] - pub request_limit: Option, + pub rate_limit: Option, /// Used by OpenAI style API for model name. #[builder(default)] @@ -313,11 +313,19 @@ pub struct HttpModelConfig { } #[derive(Serialize, Deserialize, Builder, Debug, Clone)] -pub struct RequestLimit { +pub struct RateLimit { // The limited number of requests can be made in one minute. pub request_per_minute: u64, } +impl Default for RateLimit { + fn default() -> Self { + Self { + request_per_minute: 600, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct LocalModelConfig { pub model_id: String, From 76ca9d611f6fa777d95e4c2223f81b5bad9eb04e Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 02:59:08 +0000 Subject: [PATCH 11/12] [autofix.ci] apply automated fixes --- Cargo.lock | 1 - crates/http-api-bindings/Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2e2849e3c1cb..4befd5e13083 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1901,7 +1901,6 @@ dependencies = [ "tabby-inference", "tokio", "tower 0.5.1", - "tracing", ] [[package]] diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 0f86424886a6..f6ad5644fee9 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -20,7 +20,6 @@ ollama-api-bindings = { path = "../ollama-api-bindings" } async-openai.workspace = true tower = { workspace = true , features = ["limit", "util", "buffer"] } tokio = { workspace = true } -tracing = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = ["rt", "macros"] } From 3d902968d1a56399bf2b9d2d762e04bcce261c09 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Mon, 25 Nov 2024 11:00:55 +0800 Subject: [PATCH 12/12] chore: use serde default for rate limit --- crates/http-api-bindings/src/embedding/mod.rs | 7 ++----- crates/tabby-common/src/config.rs | 3 ++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs index 1d6e3cb3d333..d5b8920af72f 100644 --- a/crates/http-api-bindings/src/embedding/mod.rs +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -8,16 +8,13 @@ use std::sync::Arc; use llama::LlamaCppEngine; use rate_limit::RateLimitedEmbedding; -use tabby_common::config::{HttpModelConfig, RateLimit}; +use tabby_common::config::HttpModelConfig; use tabby_inference::Embedding; use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine}; pub async fn create(config: &HttpModelConfig) -> Arc { - let rpm = config.rate_limit.as_ref().map_or_else( - || RateLimit::default().request_per_minute, - |rl| rl.request_per_minute, - ); + let rpm = config.rate_limit.request_per_minute; let embedding: Arc = match config.kind.as_str() { "llama.cpp/embedding" => { diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs index 56acd7590ae8..4d32b4221904 100644 --- a/crates/tabby-common/src/config.rs +++ b/crates/tabby-common/src/config.rs @@ -290,7 +290,8 @@ pub struct HttpModelConfig { pub api_key: Option, #[builder(default)] - pub rate_limit: Option, + #[serde(default)] + pub rate_limit: RateLimit, /// Used by OpenAI style API for model name. #[builder(default)]