Skip to content

Commit

Permalink
refactor(webserver): move answer routes to webserver (#2704)
Browse files Browse the repository at this point in the history
* refactor(webserver): move answer routes to webserver

* fix
  • Loading branch information
wsxiaoys authored Jul 23, 2024
1 parent 04bcd3f commit b7fea15
Show file tree
Hide file tree
Showing 16 changed files with 114 additions and 91 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/tabby-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ derive_builder.workspace = true
hash-ids.workspace = true
tracing.workspace = true
chrono.workspace = true
axum.workspace = true
axum-extra = { workspace = true, features = ["typed-header"] }

[dev-dependencies]
temp_testdir = { workspace = true }
Expand Down
31 changes: 31 additions & 0 deletions crates/tabby-common/src/axum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use axum::http::HeaderName;
use axum_extra::headers::Header;

use crate::constants::USER_HEADER_FIELD_NAME;

/// Optional user identity carried by the request header named [`USER_HEADER`].
///
/// The inner value is `None` when the header is absent from the request and
/// `Some(value)` holding the header's string value otherwise.
#[derive(Debug)]
pub struct MaybeUser(pub Option<String>);

/// Header name under which the user identity is transmitted, built once from
/// `USER_HEADER_FIELD_NAME`.
pub static USER_HEADER: HeaderName = HeaderName::from_static(USER_HEADER_FIELD_NAME);

impl Header for MaybeUser {
fn name() -> &'static axum::http::HeaderName {
&USER_HEADER
}

fn decode<'i, I>(values: &mut I) -> Result<Self, axum_extra::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i axum::http::HeaderValue>,
{
let Some(value) = values.next() else {
return Ok(MaybeUser(None));
};
let str = value.to_str().expect("User email is always a valid string");
Ok(MaybeUser(Some(str.to_string())))
}

fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
todo!()
}
}
1 change: 1 addition & 0 deletions crates/tabby-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Common tabby types and utilities.
//! Defines common types and utilities used across multiple tabby subprojects, especially serialization and deserialization targets.
pub mod api;
pub mod axum;
pub mod config;
pub mod constants;
pub mod index;
Expand Down
3 changes: 1 addition & 2 deletions crates/tabby/src/routes/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use axum::{
use axum_extra::TypedHeader;
use futures::{Stream, StreamExt};
use hyper::StatusCode;
use tabby_common::axum::MaybeUser;
use tabby_inference::ChatCompletionStream;
use tracing::{instrument, warn};

use super::MaybeUser;

#[utoipa::path(
post,
path = "/v1/chat/completions",
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby/src/routes/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::sync::Arc;
use axum::{extract::State, Json};
use axum_extra::TypedHeader;
use hyper::StatusCode;
use tabby_common::axum::MaybeUser;
use tracing::{instrument, warn};

use super::MaybeUser;
use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService};

#[utoipa::path(
Expand Down
7 changes: 4 additions & 3 deletions crates/tabby/src/routes/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use axum::{
};
use axum_extra::TypedHeader;
use hyper::StatusCode;
use tabby_common::api::event::{Event, EventLogger, LogEventRequest, SelectKind};

use super::MaybeUser;
use tabby_common::{
api::event::{Event, EventLogger, LogEventRequest, SelectKind},
axum::MaybeUser,
};

#[utoipa::path(
post,
Expand Down
33 changes: 1 addition & 32 deletions crates/tabby/src/routes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ use std::{
sync::Arc,
};

use axum::{http::HeaderName, routing, Router};
use axum_extra::headers::Header;
use axum::{routing, Router};
use axum_prometheus::PrometheusMetricLayer;
use tabby_common::constants::USER_HEADER_FIELD_NAME;
use tower_http::cors::CorsLayer;

use crate::fatal;
Expand Down Expand Up @@ -54,41 +52,12 @@ pub async fn run_app(api: Router, ui: Option<Router>, host: IpAddr, port: u16) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
}

/// Optional user identity carried by the request header named `USER_HEADER`.
///
/// The inner value is `None` when the header is absent from the request and
/// `Some(value)` holding the header's string value otherwise.
#[derive(Debug)]
pub(crate) struct MaybeUser(pub Option<String>);

/// Header name under which the user identity is transmitted, built once from
/// `USER_HEADER_FIELD_NAME`.
pub(crate) static USER_HEADER: HeaderName = HeaderName::from_static(USER_HEADER_FIELD_NAME);

impl Header for MaybeUser {
fn name() -> &'static axum::http::HeaderName {
&USER_HEADER
}

fn decode<'i, I>(values: &mut I) -> Result<Self, axum_extra::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i axum::http::HeaderValue>,
{
let Some(value) = values.next() else {
return Ok(MaybeUser(None));
};
let str = value.to_str().expect("User email is always a valid string");
Ok(MaybeUser(Some(str.to_string())))
}

fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
todo!()
}
}

mod answer;
mod chat;
mod completions;
mod events;
mod health;
mod server_setting;

pub use answer::*;
pub use chat::*;
pub use completions::*;
pub use events::*;
Expand Down
60 changes: 21 additions & 39 deletions crates/tabby/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tabby_common::{
config::{Config, ConfigAccess, ModelConfig, StaticConfigAccess},
usage,
};
use tabby_inference::Embedding;
use tabby_inference::ChatCompletionStream;
use tokio::{sync::oneshot::Sender, time::sleep};
use tower_http::timeout::TimeoutLayer;
use tracing::{debug, warn};
Expand All @@ -22,7 +22,7 @@ use utoipa_swagger_ui::SwaggerUi;
use crate::{
routes::{self, run_app},
services::{
self, answer,
self,
code::create_code_search,
completion::{self, create_completion_service},
embedding,
Expand Down Expand Up @@ -51,7 +51,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
servers(
(url = "/", description = "Server"),
),
paths(routes::log_event, routes::completions, routes::chat_completions_utoipa, routes::health, routes::answer, routes::setting),
paths(routes::log_event, routes::completions, routes::chat_completions_utoipa, routes::health, routes::setting),
components(schemas(
api::event::LogEventRequest,
completion::CompletionRequest,
Expand All @@ -67,8 +67,6 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
api::code::CodeSearchDocument,
api::code::CodeSearchQuery,
api::doc::DocSearchDocument,
answer::AnswerRequest,
answer::AnswerResponseChunk,
api::server_setting::ServerSetting
)),
modifiers(&SecurityAddon),
Expand Down Expand Up @@ -156,19 +154,29 @@ pub async fn main(config: &Config, args: &ServeArgs) {
}

let index_reader_provider = Arc::new(IndexReaderProvider::default());
let docsearch = Arc::new(services::doc::create(
embedding.clone(),
index_reader_provider.clone(),
));

let code = Arc::new(create_code_search(
config_access,
embedding.clone(),
index_reader_provider.clone(),
));

let chat = if let Some(chat) = &config.model.chat {
Some(model::load_chat_completion(chat).await)
} else {
None
};

let mut api = api_router(
args,
&config,
logger.clone(),
code.clone(),
embedding,
index_reader_provider,
chat.clone(),
webserver,
)
.await;
Expand All @@ -178,7 +186,11 @@ pub async fn main(config: &Config, args: &ServeArgs) {

#[cfg(feature = "ee")]
if let Some(ws) = &ws {
let (new_api, new_ui) = ws.attach(api, ui, code, config.model.chat.is_some()).await;
let (new_api, new_ui) = ws
.attach(api, ui, code, chat, docsearch, |x| {
Box::new(services::doc::create_serper(x))
})
.await;
api = new_api;
ui = new_ui;
};
Expand Down Expand Up @@ -210,8 +222,7 @@ async fn api_router(
config: &Config,
logger: Arc<dyn EventLogger>,
code: Arc<dyn CodeSearch>,
embedding: Arc<dyn Embedding>,
index_reader_provider: Arc<IndexReaderProvider>,
chat_state: Option<Arc<dyn ChatCompletionStream>>,
webserver: Option<bool>,
) -> Router {
let model = &config.model;
Expand All @@ -224,22 +235,6 @@ async fn api_router(
None
};

let chat_state = if let Some(chat) = &model.chat {
Some(model::load_chat_completion(chat).await)
} else {
None
};

let docsearch_state = Arc::new(services::doc::create(embedding, index_reader_provider));

let answer_state = chat_state.as_ref().map(|chat| {
Arc::new(services::answer::create(
chat.clone(),
code.clone(),
docsearch_state.clone(),
))
});

let mut routers = vec![];

let health_state = Arc::new(health::HealthState::new(
Expand Down Expand Up @@ -318,19 +313,6 @@ async fn api_router(
});
}

if let Some(answer_state) = answer_state {
routers.push({
Router::new().route(
"/v1beta/answer",
routing::post(routes::answer).with_state(answer_state),
)
});
} else {
routers.push({
Router::new().route("/v1beta/answer", routing::post(StatusCode::NOT_IMPLEMENTED))
});
}

let server_setting_router =
Router::new().route("/v1beta/server_setting", routing::get(routes::setting));

Expand Down
1 change: 0 additions & 1 deletion crates/tabby/src/services/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
pub mod answer;
pub mod code;
pub mod completion;
pub mod doc;
Expand Down
2 changes: 2 additions & 0 deletions ee/tabby-webserver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ strum.workspace = true
cron = "0.12.1"
async-stream.workspace = true
logkit.workspace = true
async-openai.workspace = true
utoipa.workspace = true

[dev-dependencies]
assert_matches.workspace = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use axum::{
};
use axum_extra::TypedHeader;
use futures::{Stream, StreamExt};
use tabby_common::axum::MaybeUser;
use tracing::instrument;

use super::MaybeUser;
use crate::services::answer::{AnswerRequest, AnswerService};
use crate::service::answer::{AnswerRequest, AnswerService};

#[utoipa::path(
post,
Expand Down
27 changes: 22 additions & 5 deletions ee/tabby-webserver/src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod answer;
mod hub;
mod oauth;
mod repositories;
Expand All @@ -23,16 +24,32 @@ use self::hub::HubState;
use crate::{
axum::{extract::AuthBearer, graphql, FromAuth},
jwt::validate_jwt,
service::answer::AnswerService,
};

pub fn create(ctx: Arc<dyn ServiceLocator>, api: Router, ui: Router) -> (Router, Router) {
pub fn create(
ctx: Arc<dyn ServiceLocator>,
api: Router,
ui: Router,
answer: Option<Arc<AnswerService>>,
) -> (Router, Router) {
let schema = Arc::new(create_schema());

let api = api
.route(
"/v1beta/server_setting",
routing::get(server_setting).with_state(ctx.clone()),
let api = api.route(
"/v1beta/server_setting",
routing::get(server_setting).with_state(ctx.clone()),
);

let api = if let Some(answer) = answer {
api.route(
"/v1beta/answer",
routing::post(answer::answer).with_state(answer),
)
} else {
api.route("/v1beta/answer", routing::post(StatusCode::NOT_IMPLEMENTED))
};

let api = api
// Routes before `distributed_tabby_layer` are protected by authentication middleware for following routes:
// 1. /v1/*
// 2. /v1beta/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@ impl AnswerService {
chat: Arc<dyn ChatCompletionStream>,
code: Arc<dyn CodeSearch>,
doc: Arc<dyn DocSearch>,
serper_factory_fn: impl Fn(&str) -> Box<dyn DocSearch>,
) -> Self {
let serper: Option<Box<dyn DocSearch>> =
if let Ok(api_key) = std::env::var("SERPER_API_KEY") {
debug!("Serper API key found, enabling serper...");
Some(Box::new(super::doc::create_serper(api_key.as_str())))
Some(serper_factory_fn(&api_key))
} else {
None
};
Expand Down Expand Up @@ -363,8 +364,9 @@ pub fn create(
chat: Arc<dyn ChatCompletionStream>,
code: Arc<dyn CodeSearch>,
doc: Arc<dyn DocSearch>,
serper_factory_fn: impl Fn(&str) -> Box<dyn DocSearch>,
) -> AnswerService {
AnswerService::new(chat, code, doc)
AnswerService::new(chat, code, doc, serper_factory_fn)
}

fn get_content(message: &ChatCompletionRequestMessage) -> &str {
Expand Down
Loading

0 comments on commit b7fea15

Please sign in to comment.