
feat(models-http-api): add rate limit in embedding api #3455

Closed · wants to merge 12 commits
38 changes: 28 additions & 10 deletions Cargo.lock


1 change: 1 addition & 0 deletions Cargo.toml
@@ -63,6 +63,7 @@ url = "2.5.0"
 temp_testdir = "0.2"
 git2 = "0.18.3"
 tower-http = "0.5"
+tower = "0.5"
 mime_guess = "2.0.4"
 assert_matches = "1.5"
 insta = "1.34.0"
2 changes: 2 additions & 0 deletions crates/http-api-bindings/Cargo.toml
@@ -18,6 +18,8 @@ tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
 ollama-api-bindings = { path = "../ollama-api-bindings" }
 async-openai.workspace = true
+tower = { workspace = true, features = ["limit", "util", "buffer"] }
+tokio = { workspace = true }
 
 [dev-dependencies]
 tokio = { workspace = true, features = ["rt", "macros"] }
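
The new `tower` dependency (with the `limit` and `util` features) supplies the rate-limiting primitive that the embedding wrapper below builds on: `ServiceBuilder::rate_limit(num, per)` admits at most `num` calls per `per` window, and `ServiceExt::ready` waits for capacity. A minimal standalone sketch of that primitive, assuming `tokio` and `anyhow` as used in this crate:

```rust
use std::time::Duration;

use tower::{service_fn, Service, ServiceBuilder, ServiceExt};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Allow at most 2 calls per 1-second window.
    let mut svc = ServiceBuilder::new()
        .rate_limit(2, Duration::from_secs(1))
        .service(service_fn(|x: u32| async move {
            Ok::<u32, anyhow::Error>(x * 2)
        }));

    for i in 0..3u32 {
        // `ready()` resolves immediately for the first two calls and
        // parks the third until the window rolls over.
        let doubled = svc.ready().await?.call(i).await?;
        println!("{i} -> {doubled}");
    }
    Ok(())
}
```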
12 changes: 10 additions & 2 deletions crates/http-api-bindings/src/embedding/mod.rs
@@ -1,18 +1,22 @@
 mod llama;
 mod openai;
+mod rate_limit;
 mod voyage;
 
 use core::panic;
 use std::sync::Arc;
 
 use llama::LlamaCppEngine;
+use rate_limit::RateLimitedEmbedding;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::Embedding;
 
 use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
-    match config.kind.as_str() {
+    let rpm = config.rate_limit.request_per_minute;
+
+    let embedding: Arc<dyn Embedding> = match config.kind.as_str() {
         "llama.cpp/embedding" => {
             let engine = LlamaCppEngine::create(
                 config
@@ -53,5 +57,9 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
"Unsupported kind for http embedding model: {}",
unsupported_kind
),
}
};

Arc::new(
RateLimitedEmbedding::new(embedding, rpm).expect("Failed to create rate limited embedding"),
)
}
63 changes: 63 additions & 0 deletions crates/http-api-bindings/src/embedding/rate_limit.rs
@@ -0,0 +1,63 @@
use std::{
    sync::Arc,
    task::{Context, Poll},
    time,
};

use async_trait::async_trait;
use futures::future::BoxFuture;
use tabby_inference::Embedding;
use tokio::sync::Mutex;
use tower::{Service, ServiceBuilder, ServiceExt};

struct EmbeddingService {
    embedding: Arc<dyn Embedding>,
}

impl Service<String> for EmbeddingService {
    type Response = Vec<f32>;
    type Error = anyhow::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, prompt: String) -> Self::Future {
        let embedding = self.embedding.clone();
        Box::pin(async move { embedding.embed(&prompt).await })
    }
}

pub struct RateLimitedEmbedding {
    embedding: Arc<Mutex<tower::util::BoxService<String, Vec<f32>, anyhow::Error>>>,
}

impl RateLimitedEmbedding {
    pub fn new(embedding: Arc<dyn Embedding>, rpm: u64) -> anyhow::Result<Self> {
        if rpm == 0 {
            anyhow::bail!(
                "Can not create rate limited embedding client with 0 requests per minute"
            );
        }

        let service = ServiceBuilder::new()
            .rate_limit(rpm, time::Duration::from_secs(60))
            .service(EmbeddingService { embedding })
            .boxed();

        Ok(Self {
            embedding: Arc::new(Mutex::new(service)),
        })
    }
}

#[async_trait]
impl Embedding for RateLimitedEmbedding {
    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
        let mut service = self.embedding.lock().await;
        let prompt_owned = prompt.to_string();
        let response = service.ready().await?.call(prompt_owned).await?;
        Ok(response)
    }
}
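
`RateLimitedEmbedding` is module-private and normally constructed via `create()` in `mod.rs`, but a hypothetical in-module test sketches the behavior. `DummyEmbedding` is a stub invented here for illustration, and `#[tokio::test]` relies on the `rt`/`macros` dev-features already declared above:

```rust
use std::sync::Arc;

use async_trait::async_trait;
use tabby_inference::Embedding;

// Hypothetical stub implementation, for illustration only.
struct DummyEmbedding;

#[async_trait]
impl Embedding for DummyEmbedding {
    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
        Ok(vec![prompt.len() as f32])
    }
}

#[tokio::test]
async fn rate_limited_embedding_passes_through() -> anyhow::Result<()> {
    // 120 requests per minute; once the window is exhausted, `embed`
    // waits transparently instead of returning an error.
    let limited = RateLimitedEmbedding::new(Arc::new(DummyEmbedding), 120)?;
    assert_eq!(limited.embed("hello").await?, vec![5.0]);
    Ok(())
}
```

The `Mutex` around the boxed service is needed because tower's `Service::call` takes `&mut self` while `Embedding::embed` takes `&self`; serializing callers through the lock is the simplest bridge between the two.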
18 changes: 18 additions & 0 deletions crates/tabby-common/src/config.rs
@@ -289,6 +289,10 @@ pub struct HttpModelConfig {
     #[builder(default)]
     pub api_key: Option<String>,
 
+    #[builder(default)]
+    #[serde(default)]
+    pub rate_limit: RateLimit,
+
     /// Used by OpenAI style API for model name.
     #[builder(default)]
     pub model_name: Option<String>,
@@ -309,6 +313,20 @@ pub struct HttpModelConfig {
     pub additional_stop_words: Option<Vec<String>>,
 }
 
+#[derive(Serialize, Deserialize, Builder, Debug, Clone)]
+pub struct RateLimit {
+    /// The maximum number of requests allowed per minute.
+    pub request_per_minute: u64,
+}
+
+impl Default for RateLimit {
+    fn default() -> Self {
+        Self {
+            request_per_minute: 600,
+        }
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 pub struct LocalModelConfig {
     pub model_id: String,
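
Because the field is marked `#[serde(default)]` and `RateLimit` implements `Default`, a configuration that omits the section falls back to 600 requests per minute. A minimal sketch of that behavior, assuming the `toml` crate is available; `Wrapper` is a stand-in invented here for the relevant slice of `HttpModelConfig`:

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct RateLimit {
    request_per_minute: u64,
}

impl Default for RateLimit {
    fn default() -> Self {
        Self {
            request_per_minute: 600,
        }
    }
}

// Stand-in for the relevant slice of `HttpModelConfig`.
#[derive(Deserialize, Debug)]
struct Wrapper {
    #[serde(default)]
    rate_limit: RateLimit,
}

fn main() {
    // Omitting the section falls back to the default of 600 rpm.
    let absent: Wrapper = toml::from_str("").unwrap();
    assert_eq!(absent.rate_limit.request_per_minute, 600);

    // An explicit section overrides it.
    let present: Wrapper =
        toml::from_str("[rate_limit]\nrequest_per_minute = 120\n").unwrap();
    assert_eq!(present.rate_limit.request_per_minute, 120);
}
```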