diff --git a/Cargo.lock b/Cargo.lock index 43342f8a414c..7cf78d794dfd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5060,6 +5060,8 @@ version = "0.15.0-dev.0" dependencies = [ "anyhow", "async-trait", + "axum", + "axum-extra", "chrono", "derive_builder", "hash-ids", @@ -5244,6 +5246,7 @@ dependencies = [ "anyhow", "argon2", "assert_matches", + "async-openai", "async-stream", "async-trait", "axum", @@ -5287,6 +5290,7 @@ dependencies = [ "tracing", "url", "urlencoding", + "utoipa", "uuid", ] diff --git a/crates/tabby-common/Cargo.toml b/crates/tabby-common/Cargo.toml index 8f0656681357..15c32235b75d 100644 --- a/crates/tabby-common/Cargo.toml +++ b/crates/tabby-common/Cargo.toml @@ -23,6 +23,8 @@ derive_builder.workspace = true hash-ids.workspace = true tracing.workspace = true chrono.workspace = true +axum.workspace = true +axum-extra = { workspace = true, features = ["typed-header"] } [dev-dependencies] temp_testdir = { workspace = true } diff --git a/crates/tabby-common/src/axum.rs b/crates/tabby-common/src/axum.rs new file mode 100644 index 000000000000..477d669de32b --- /dev/null +++ b/crates/tabby-common/src/axum.rs @@ -0,0 +1,31 @@ +use axum::http::HeaderName; +use axum_extra::headers::Header; + +use crate::constants::USER_HEADER_FIELD_NAME; + +#[derive(Debug)] +pub struct MaybeUser(pub Option); + +pub static USER_HEADER: HeaderName = HeaderName::from_static(USER_HEADER_FIELD_NAME); + +impl Header for MaybeUser { + fn name() -> &'static axum::http::HeaderName { + &USER_HEADER + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let Some(value) = values.next() else { + return Ok(MaybeUser(None)); + }; + let str = value.to_str().expect("User email is always a valid string"); + Ok(MaybeUser(Some(str.to_string()))) + } + + fn encode>(&self, _values: &mut E) { + todo!() + } +} diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs index 87ed80c7776b..8d81958238d6 100644 --- a/crates/tabby-common/src/lib.rs +++ b/crates/tabby-common/src/lib.rs @@ -1,6 +1,7 @@ //! Common tabby types and utilities. //! Defines common types and utilities used across multiple tabby subprojects, especially serialization and deserialization targets. pub mod api; +pub mod axum; pub mod config; pub mod constants; pub mod index; diff --git a/crates/tabby/src/routes/chat.rs b/crates/tabby/src/routes/chat.rs index 60ba369be268..7ed9920c5f13 100644 --- a/crates/tabby/src/routes/chat.rs +++ b/crates/tabby/src/routes/chat.rs @@ -8,11 +8,10 @@ use axum::{ use axum_extra::TypedHeader; use futures::{Stream, StreamExt}; use hyper::StatusCode; +use tabby_common::axum::MaybeUser; use tabby_inference::ChatCompletionStream; use tracing::{instrument, warn}; -use super::MaybeUser; - #[utoipa::path( post, path = "/v1/chat/completions", diff --git a/crates/tabby/src/routes/completions.rs b/crates/tabby/src/routes/completions.rs index d8a86556326b..07f73c372e16 100644 --- a/crates/tabby/src/routes/completions.rs +++ b/crates/tabby/src/routes/completions.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use axum::{extract::State, Json}; use axum_extra::TypedHeader; use hyper::StatusCode; +use tabby_common::axum::MaybeUser; use tracing::{instrument, warn}; -use super::MaybeUser; use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService}; #[utoipa::path( diff --git a/crates/tabby/src/routes/events.rs b/crates/tabby/src/routes/events.rs index a008e9e9ef34..0a044dcc58f3 100644 --- a/crates/tabby/src/routes/events.rs +++ b/crates/tabby/src/routes/events.rs @@ -6,9 +6,10 @@ use axum::{ }; use axum_extra::TypedHeader; use hyper::StatusCode; -use tabby_common::api::event::{Event, EventLogger, LogEventRequest, SelectKind}; - -use super::MaybeUser; +use tabby_common::{ + api::event::{Event, EventLogger, LogEventRequest, SelectKind}, + axum::MaybeUser, +}; #[utoipa::path( post, diff --git a/crates/tabby/src/routes/mod.rs b/crates/tabby/src/routes/mod.rs index 30ae15d33eaa..2025c512f315 100644 --- a/crates/tabby/src/routes/mod.rs +++ b/crates/tabby/src/routes/mod.rs @@ -5,10 +5,8 @@ use std::{ sync::Arc, }; -use axum::{http::HeaderName, routing, Router}; -use axum_extra::headers::Header; +use axum::{routing, Router}; use axum_prometheus::PrometheusMetricLayer; -use tabby_common::constants::USER_HEADER_FIELD_NAME; use tower_http::cors::CorsLayer; use crate::fatal; @@ -54,41 +52,12 @@ pub async fn run_app(api: Router, ui: Option, host: IpAddr, port: u16) { .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) } -#[derive(Debug)] -pub(crate) struct MaybeUser(pub Option); - -pub(crate) static USER_HEADER: HeaderName = HeaderName::from_static(USER_HEADER_FIELD_NAME); - -impl Header for MaybeUser { - fn name() -> &'static axum::http::HeaderName { - &USER_HEADER - } - - fn decode<'i, I>(values: &mut I) -> Result - where - Self: Sized, - I: Iterator, - { - let Some(value) = values.next() else { - return Ok(MaybeUser(None)); - }; - let str = value.to_str().expect("User email is always a valid string"); - Ok(MaybeUser(Some(str.to_string()))) - } - - fn encode>(&self, _values: &mut E) { - todo!() - } -} - -mod answer; mod chat; mod completions; mod events; mod health; mod server_setting; -pub use answer::*; pub use chat::*; pub use completions::*; pub use events::*; diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index 5a0bce29845d..0c5c8782d480 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -9,7 +9,7 @@ use tabby_common::{ config::{Config, ConfigAccess, ModelConfig, StaticConfigAccess}, usage, }; -use tabby_inference::Embedding; +use tabby_inference::ChatCompletionStream; use tokio::{sync::oneshot::Sender, time::sleep}; use tower_http::timeout::TimeoutLayer; use tracing::{debug, warn}; @@ -22,7 +22,7 @@ use utoipa_swagger_ui::SwaggerUi; use crate::{ routes::{self, run_app}, services::{ - self, answer, + self, code::create_code_search, completion::{self, create_completion_service}, embedding, @@ -51,7 +51,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi servers( (url = "/", description = "Server"), ), - paths(routes::log_event, routes::completions, routes::chat_completions_utoipa, routes::health, routes::answer, routes::setting), + paths(routes::log_event, routes::completions, routes::chat_completions_utoipa, routes::health, routes::setting), components(schemas( api::event::LogEventRequest, completion::CompletionRequest, @@ -67,8 +67,6 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi api::code::CodeSearchDocument, api::code::CodeSearchQuery, api::doc::DocSearchDocument, - answer::AnswerRequest, - answer::AnswerResponseChunk, api::server_setting::ServerSetting )), modifiers(&SecurityAddon), @@ -156,19 +154,29 @@ pub async fn main(config: &Config, args: &ServeArgs) { } let index_reader_provider = Arc::new(IndexReaderProvider::default()); + let docsearch = Arc::new(services::doc::create( + embedding.clone(), + index_reader_provider.clone(), + )); let code = Arc::new(create_code_search( config_access, embedding.clone(), index_reader_provider.clone(), )); + + let chat = if let Some(chat) = &config.model.chat { + Some(model::load_chat_completion(chat).await) + } else { + None + }; + let mut api = api_router( args, &config, logger.clone(), code.clone(), - embedding, - index_reader_provider, + chat.clone(), webserver, ) .await; @@ -178,7 +186,11 @@ pub async fn main(config: &Config, args: &ServeArgs) { #[cfg(feature = "ee")] if let Some(ws) = &ws { - let (new_api, new_ui) = ws.attach(api, ui, code, config.model.chat.is_some()).await; + let (new_api, new_ui) = ws + .attach(api, ui, code, chat, docsearch, |x| { + Box::new(services::doc::create_serper(x)) + }) + .await; api = new_api; ui = new_ui; }; @@ -210,8 +222,7 @@ async fn api_router( config: &Config, logger: Arc, code: Arc, - embedding: Arc, - index_reader_provider: Arc, + chat_state: Option>, webserver: Option, ) -> Router { let model = &config.model; @@ -224,22 +235,6 @@ async fn api_router( None }; - let chat_state = if let Some(chat) = &model.chat { - Some(model::load_chat_completion(chat).await) - } else { - None - }; - - let docsearch_state = Arc::new(services::doc::create(embedding, index_reader_provider)); - - let answer_state = chat_state.as_ref().map(|chat| { - Arc::new(services::answer::create( - chat.clone(), - code.clone(), - docsearch_state.clone(), - )) - }); - let mut routers = vec![]; let health_state = Arc::new(health::HealthState::new( @@ -318,19 +313,6 @@ async fn api_router( }); } - if let Some(answer_state) = answer_state { - routers.push({ - Router::new().route( - "/v1beta/answer", - routing::post(routes::answer).with_state(answer_state), - ) - }); - } else { - routers.push({ - Router::new().route("/v1beta/answer", routing::post(StatusCode::NOT_IMPLEMENTED)) - }); - } - let server_setting_router = Router::new().route("/v1beta/server_setting", routing::get(routes::setting)); diff --git a/crates/tabby/src/services/mod.rs b/crates/tabby/src/services/mod.rs index c0988c362e26..0c7d5892b252 100644 --- a/crates/tabby/src/services/mod.rs +++ b/crates/tabby/src/services/mod.rs @@ -1,4 +1,3 @@ -pub mod answer; pub mod code; pub mod completion; pub mod doc; diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 46bee3e21e09..315bd111d613 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -53,6 +53,8 @@ strum.workspace = true cron = "0.12.1" async-stream.workspace = true logkit.workspace = true +async-openai.workspace = true +utoipa.workspace = true [dev-dependencies] assert_matches.workspace = true diff --git a/crates/tabby/src/routes/answer.rs b/ee/tabby-webserver/src/routes/answer.rs similarity index 92% rename from crates/tabby/src/routes/answer.rs rename to ee/tabby-webserver/src/routes/answer.rs index 5e24474e9dd4..6edbe7a7e81b 100644 --- a/crates/tabby/src/routes/answer.rs +++ b/ee/tabby-webserver/src/routes/answer.rs @@ -8,10 +8,10 @@ use axum::{ }; use axum_extra::TypedHeader; use futures::{Stream, StreamExt}; +use tabby_common::axum::MaybeUser; use tracing::instrument; -use super::MaybeUser; -use crate::services::answer::{AnswerRequest, AnswerService}; +use crate::service::answer::{AnswerRequest, AnswerService}; #[utoipa::path( post, diff --git a/ee/tabby-webserver/src/routes/mod.rs b/ee/tabby-webserver/src/routes/mod.rs index ca28de94201c..425b8a79da67 100644 --- a/ee/tabby-webserver/src/routes/mod.rs +++ b/ee/tabby-webserver/src/routes/mod.rs @@ -1,3 +1,4 @@ +mod answer; mod hub; mod oauth; mod repositories; @@ -23,16 +24,32 @@ use self::hub::HubState; use crate::{ axum::{extract::AuthBearer, graphql, FromAuth}, jwt::validate_jwt, + service::answer::AnswerService, }; -pub fn create(ctx: Arc, api: Router, ui: Router) -> (Router, Router) { +pub fn create( + ctx: Arc, + api: Router, + ui: Router, + answer: Option>, +) -> (Router, Router) { let schema = Arc::new(create_schema()); - let api = api - .route( - "/v1beta/server_setting", - routing::get(server_setting).with_state(ctx.clone()), + let api = api.route( + "/v1beta/server_setting", + routing::get(server_setting).with_state(ctx.clone()), + ); + + let api = if let Some(answer) = answer { + api.route( + "/v1beta/answer", + routing::post(answer::answer).with_state(answer), ) + } else { + api.route("/v1beta/answer", routing::post(StatusCode::NOT_IMPLEMENTED)) + }; + + let api = api // Routes before `distributed_tabby_layer` are protected by authentication middleware for following routes: // 1. /v1/* // 2. /v1beta/* diff --git a/crates/tabby/src/services/answer.rs b/ee/tabby-webserver/src/service/answer.rs similarity index 98% rename from crates/tabby/src/services/answer.rs rename to ee/tabby-webserver/src/service/answer.rs index b09553dad9e3..0ad4ad04f19e 100644 --- a/crates/tabby/src/services/answer.rs +++ b/ee/tabby-webserver/src/service/answer.rs @@ -64,11 +64,12 @@ impl AnswerService { chat: Arc, code: Arc, doc: Arc, + serper_factory_fn: impl Fn(&str) -> Box, ) -> Self { let serper: Option> = if let Ok(api_key) = std::env::var("SERPER_API_KEY") { debug!("Serper API key found, enabling serper..."); - Some(Box::new(super::doc::create_serper(api_key.as_str()))) + Some(serper_factory_fn(&api_key)) } else { None }; @@ -363,8 +364,9 @@ pub fn create( chat: Arc, code: Arc, doc: Arc, + serper_factory_fn: impl Fn(&str) -> Box, ) -> AnswerService { - AnswerService::new(chat, code, doc) + AnswerService::new(chat, code, doc, serper_factory_fn) } fn get_content(message: &ChatCompletionRequestMessage) -> &str { diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 892942f6cf3d..ed09460e45a5 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -1,4 +1,5 @@ mod analytic; +pub mod answer; mod auth; pub mod background_job; mod email; diff --git a/ee/tabby-webserver/src/webserver.rs b/ee/tabby-webserver/src/webserver.rs index c69e347df757..c6a97ccd98d9 100644 --- a/ee/tabby-webserver/src/webserver.rs +++ b/ee/tabby-webserver/src/webserver.rs @@ -4,12 +4,13 @@ use axum::Router; use tabby_common::{ api::{ code::CodeSearch, + doc::DocSearch, event::{ComposedLogger, EventLogger}, }, config::{Config, ConfigAccess, RepositoryConfig}, }; use tabby_db::DbConn; -use tabby_inference::Embedding; +use tabby_inference::{ChatCompletionStream, Embedding}; use tabby_schema::{ integration::IntegrationService, job::JobService, repository::RepositoryService, web_crawler::WebCrawlerService, @@ -96,11 +97,14 @@ impl Webserver { api: Router, ui: Router, code: Arc, - is_chat_enabled: bool, + chat: Option>, + docsearch: Arc, + serper_factory_fn: impl Fn(&str) -> Box, ) -> (Router, Router) { + let is_chat_enabled = chat.is_some(); let ctx = create_service_locator( self.logger(), - code, + code.clone(), self.repository.clone(), self.integration.clone(), self.web_crawler.clone(), @@ -110,6 +114,15 @@ impl Webserver { ) .await; - routes::create(ctx, api, ui) + let answer = chat.as_ref().map(|chat| { + Arc::new(crate::service::answer::create( + chat.clone(), + code.clone(), + docsearch.clone(), + serper_factory_fn, + )) + }); + + routes::create(ctx, api, ui, answer) } }