Skip to content

Commit

Permalink
refactor(core): remove CodeRepositoryAccess, add AllowedCodeRepositor… (
Browse files Browse the repository at this point in the history
#3185)

* refactor(core): remove CodeRepositoryAccess, add AllowedCodeRepository as axum Extension

* update

* Update ee/tabby-webserver/src/service/user_group.rs

* refactor(webserver): remove more useless fields in Webserver

* cleanup
  • Loading branch information
wsxiaoys authored Sep 23, 2024
1 parent bb7965d commit 385d7ae
Show file tree
Hide file tree
Showing 17 changed files with 290 additions and 282 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/tabby-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ tracing.workspace = true
chrono.workspace = true
axum.workspace = true
axum-extra = { workspace = true, features = ["typed-header"] }
parse-git-url = "0.5.1"

[dev-dependencies]
temp_testdir = { workspace = true }
Expand Down
7 changes: 2 additions & 5 deletions crates/tabby-common/src/api/code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ pub enum CodeSearchError {

#[derive(Deserialize, ToSchema)]
pub struct CodeSearchQuery {
pub git_url: Option<String>,
pub filepath: Option<String>,
pub language: Option<String>,
pub content: String,
Expand All @@ -69,18 +68,16 @@ pub struct CodeSearchQuery {

impl CodeSearchQuery {
pub fn new(
git_url: Option<String>,
filepath: Option<String>,
language: Option<String>,
content: String,
source_id: Option<String>,
source_id: String,
) -> Self {
Self {
git_url,
filepath,
language,
content,
source_id: source_id.unwrap_or_default(),
source_id,
}
}
}
Expand Down
158 changes: 157 additions & 1 deletion crates/tabby-common/src/axum.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use axum::http::HeaderName;
use axum_extra::headers::Header;

use crate::constants::USER_HEADER_FIELD_NAME;
use crate::{config::CodeRepository, constants::USER_HEADER_FIELD_NAME};

#[derive(Debug)]
pub struct MaybeUser(pub Option<String>);
Expand Down Expand Up @@ -29,3 +29,159 @@ impl Header for MaybeUser {
todo!()
}
}

/// The set of code repositories a request is allowed to search.
///
/// Handlers receive this value through an axum `Extension`, and use
/// [`AllowedCodeRepository::closest_match`] to translate a git URL into the
/// `source_id` of a repository from this set.
#[derive(Debug, Default, Clone)]
pub struct AllowedCodeRepository {
    /// Repositories that lookups are restricted to.
    list: Vec<CodeRepository>,
}

impl AllowedCodeRepository {
    /// Wraps an explicit list of repositories.
    pub fn new(list: Vec<CodeRepository>) -> Self {
        Self { list }
    }

    /// Builds the allowed set from the repositories declared in the loaded
    /// configuration.
    ///
    /// Each repository's source id is derived from its position in the
    /// configuration via `config_index_to_id`. If the configuration cannot be
    /// loaded, the resulting set is empty rather than an error.
    pub fn new_from_config() -> Self {
        let list: Vec<CodeRepository> = match crate::config::Config::load() {
            Ok(config) => config
                .repositories
                .into_iter()
                .enumerate()
                .map(|(index, repo)| {
                    CodeRepository::new(repo.git_url(), &crate::config::config_index_to_id(index))
                })
                .collect(),
            // A missing or unreadable config simply means nothing is allowed.
            Err(_) => Vec::new(),
        };

        Self { list }
    }

    /// Returns the `source_id` of the allowed repository that best matches
    /// `git_url`, or `None` when no allowed repository matches.
    pub fn closest_match(&self, git_url: &str) -> Option<&str> {
        closest_match(git_url, self.list.iter())
    }
}

/// Finds the repository whose parsed git URL has the same repository name as
/// `git_url` and returns its `source_id`.
///
/// Returns `None` when `git_url` cannot be parsed or when no candidate shares
/// its repository name. Candidates whose own URL fails to parse are ignored.
/// Only the parsed `name` is compared.
fn closest_match<'a>(
    git_url: &str,
    repositories: impl IntoIterator<Item = &'a CodeRepository>,
) -> Option<&'a str> {
    let wanted = match parse_git_url::GitUrl::parse(git_url) {
        Ok(parsed) => parsed,
        Err(_) => return None,
    };

    let winner = repositories
        .into_iter()
        .filter(|repo| match parse_git_url::GitUrl::parse(&repo.git_url) {
            Ok(parsed) => parsed.name == wanted.name,
            Err(_) => false,
        })
        // With several candidates, the one whose canonical git URL compares
        // smallest is chosen, so the result does not depend on input order.
        .min_by_key(|repo| repo.canonical_git_url())?;

    Some(winner.source_id.as_str())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one `CodeRepository` per candidate URL (source ids derived from
    /// the candidate's index) and asserts that `$query` resolves to the FIRST
    /// candidate's source id.
    macro_rules! assert_match_first {
        ($query:literal, $candidates:expr) => {
            let candidates: Vec<_> = $candidates
                .into_iter()
                .enumerate()
                .map(|(i, x)| CodeRepository::new(&x, &crate::config::config_index_to_id(i)))
                .collect();
            let expect = &candidates[0];
            assert_eq!(
                closest_match($query, &candidates),
                Some(expect.source_id.as_ref())
            );
        };
    }

    /// Builds one `CodeRepository` per candidate URL and asserts that `$query`
    /// matches none of them.
    macro_rules! assert_match_none {
        ($query:literal, $candidates:expr) => {
            let candidates: Vec<_> = $candidates
                .into_iter()
                .enumerate()
                .map(|(i, x)| CodeRepository::new(&x, &crate::config::config_index_to_id(i)))
                .collect();
            assert_eq!(closest_match($query, &candidates), None);
        };
    }

    #[test]
    fn test_closest_match() {
        // Test .git suffix should still match
        assert_match_first!(
            "https://github.com/example/test.git",
            ["https://github.com/example/test"]
        );

        // Test auth in URL should still match
        assert_match_first!(
            "https://creds@github.com/example/test",
            ["https://github.com/example/test"]
        );

        // Test name must be exact match
        assert_match_none!(
            "https://github.com/example/another-repo",
            ["https://github.com/example/anoth-repo"]
        );

        // Test different repositories with a common prefix should not match
        assert_match_none!(
            "https://github.com/TabbyML/tabby",
            ["https://github.com/TabbyML/registry-tabby"]
        );

        // Test entirely different repository names should not match
        assert_match_none!(
            "https://github.com/TabbyML/tabby",
            ["https://github.com/TabbyML/uptime"]
        );

        // Test a URL without any repository path should not match
        assert_match_none!("https://github.com", ["https://github.com/TabbyML/tabby"]);

        // Test different host
        assert_match_first!(
            "https://bitbucket.com/TabbyML/tabby",
            ["https://github.com/TabbyML/tabby"]
        );

        // Test several near-miss candidates: none shares the exact name
        assert_match_none!(
            "git@github.com:TabbyML/tabby",
            [
                "https://bitbucket.com/CrabbyML/crabby",
                "https://gitlab.com/TabbyML/registry-tabby",
            ]
        );
    }

    #[test]
    fn test_closest_match_url_format_differences() {
        // Test different protocol and suffix should still match
        assert_match_first!(
            "git@github.com:TabbyML/tabby.git",
            ["https://github.com/TabbyML/tabby"]
        );

        // Test different protocol should still match
        assert_match_first!(
            "git@github.com:TabbyML/tabby",
            ["https://github.com/TabbyML/tabby"]
        );

        // Test URL without organization should still match
        assert_match_first!(
            "https://custom-git.com/tabby",
            ["https://custom-git.com/TabbyML/tabby"]
        );
    }

    #[test]
    fn test_closest_match_local_url() {
        // Test an SSH remote should match a local file:// repository of the same name
        assert_match_first!(
            "git@github.com:TabbyML/tabby.git",
            ["file:///home/TabbyML/tabby"]
        );
    }
}
20 changes: 0 additions & 20 deletions crates/tabby-common/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{collections::HashSet, path::PathBuf};

use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use derive_builder::Builder;
use hash_ids::HashIds;
use lazy_static::lazy_static;
Expand Down Expand Up @@ -351,25 +350,6 @@ impl CodeRepository {
}
}

#[async_trait]
pub trait CodeRepositoryAccess: Send + Sync {
async fn repositories(&self) -> Result<Vec<CodeRepository>>;
}

pub struct StaticCodeRepositoryAccess;

#[async_trait]
impl CodeRepositoryAccess for StaticCodeRepositoryAccess {
async fn repositories(&self) -> Result<Vec<CodeRepository>> {
Ok(Config::load()?
.repositories
.into_iter()
.enumerate()
.map(|(i, repo)| CodeRepository::new(&repo.git_url, &config_index_to_id(i)))
.collect())
}
}

#[cfg(test)]
mod tests {
use super::{sanitize_name, Config, RepositoryConfig};
Expand Down
2 changes: 0 additions & 2 deletions crates/tabby/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ thiserror.workspace = true
chrono.workspace = true
axum-prometheus = "0.6"
uuid.workspace = true
cached = { workspace = true, features = ["async"] }
parse-git-url = "0.5.1"
color-eyre = { version = "0.6.3" }
reqwest.workspace = true
async-openai.workspace = true
Expand Down
10 changes: 7 additions & 3 deletions crates/tabby/src/routes/completions.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::sync::Arc;

use axum::{extract::State, Json};
use axum::{extract::State, Extension, Json};
use axum_extra::{headers, TypedHeader};
use hyper::StatusCode;
use tabby_common::axum::MaybeUser;
use tabby_common::axum::{AllowedCodeRepository, MaybeUser};
use tracing::{instrument, warn};

use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService};
Expand All @@ -25,6 +25,7 @@ use crate::services::completion::{CompletionRequest, CompletionResponse, Complet
#[instrument(skip(state, request))]
pub async fn completions(
State(state): State<Arc<CompletionService>>,
Extension(allowed_code_repository): Extension<AllowedCodeRepository>,
TypedHeader(MaybeUser(user)): TypedHeader<MaybeUser>,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Json(mut request): Json<CompletionRequest>,
Expand All @@ -35,7 +36,10 @@ pub async fn completions(

let user_agent = user_agent.map(|x| x.0.to_string());

match state.generate(&request, user_agent.as_deref()).await {
match state
.generate(&request, &allowed_code_repository, user_agent.as_deref())
.await
{
Ok(resp) => Ok(Json(resp)),
Err(err) => {
warn!("{}", err);
Expand Down
32 changes: 17 additions & 15 deletions crates/tabby/src/serve.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::{net::IpAddr, sync::Arc, time::Duration};

use axum::{routing, Router};
use axum::{routing, Extension, Router};
use clap::Args;
use hyper::StatusCode;
use spinners::{Spinner, Spinners, Stream};
use tabby_common::{
api::{self, code::CodeSearch, event::EventLogger},
config::{CodeRepositoryAccess, Config, ModelConfig, StaticCodeRepositoryAccess},
axum::AllowedCodeRepository,
config::{Config, ModelConfig},
usage,
};
use tabby_inference::ChatCompletionStream;
Expand Down Expand Up @@ -148,12 +149,10 @@ pub async fn main(config: &Config, args: &ServeArgs) {
};

let mut logger: Arc<dyn EventLogger> = Arc::new(create_event_logger());
let mut config_access: Arc<dyn CodeRepositoryAccess> = Arc::new(StaticCodeRepositoryAccess);

#[cfg(feature = "ee")]
if let Some(ws) = &ws {
logger = ws.logger();
config_access = ws.clone();
}

let index_reader_provider = Arc::new(IndexReaderProvider::default());
Expand All @@ -163,7 +162,6 @@ pub async fn main(config: &Config, args: &ServeArgs) {
));

let code = Arc::new(create_code_search(
config_access,
embedding.clone(),
index_reader_provider.clone(),
));
Expand Down Expand Up @@ -262,16 +260,20 @@ async fn api_router(
});

if let Some(completion_state) = completion_state {
routers.push({
Router::new()
.route(
"/v1/completions",
routing::post(routes::completions).with_state(Arc::new(completion_state)),
)
.layer(TimeoutLayer::new(Duration::from_secs(
config.server.completion_timeout,
)))
});
let mut router = Router::new()
.route(
"/v1/completions",
routing::post(routes::completions).with_state(Arc::new(completion_state)),
)
.layer(TimeoutLayer::new(Duration::from_secs(
config.server.completion_timeout,
)));

if webserver.is_none() || webserver.is_some_and(|x| !x) {
router = router.layer(Extension(AllowedCodeRepository::new_from_config()));
}

routers.push(router);
} else {
routers.push({
Router::new().route(
Expand Down
Loading

0 comments on commit 385d7ae

Please sign in to comment.