Skip to content

Commit

Permalink
Merge pull request #1060 from TabbyML/extract-public-mod
Browse files Browse the repository at this point in the history
refactor(webserver): extract tabby_webserver::public to centralize it…
  • Loading branch information
wsxiaoys authored Dec 16, 2023
2 parents 4df74f1 + 1f00e4a commit 3a33249
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 66 deletions.
4 changes: 2 additions & 2 deletions crates/tabby/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ async fn main() {
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
#[cfg(feature = "ee")]
Commands::WorkerCompletion(args) => {
worker::main(tabby_webserver::api::WorkerKind::Completion, args).await
worker::main(tabby_webserver::public::WorkerKind::Completion, args).await
}
#[cfg(feature = "ee")]
Commands::WorkerChat(args) => {
worker::main(tabby_webserver::api::WorkerKind::Chat, args).await
worker::main(tabby_webserver::public::WorkerKind::Chat, args).await
}
}

Expand Down
22 changes: 12 additions & 10 deletions crates/tabby/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{env::consts::ARCH, net::IpAddr, sync::Arc};

use axum::{routing, Router};
use clap::Args;
use tabby_webserver::api::{HubClient, WorkerKind};
use tabby_webserver::public::{HubClient, RegisterWorkerRequest, WorkerKind};
use tracing::info;

use crate::{
Expand Down Expand Up @@ -94,17 +94,19 @@ impl WorkerContext {
let cuda_devices = read_cuda_devices().unwrap_or_default();

Self {
client: tabby_webserver::api::create_client(
client: tabby_webserver::public::create_client(
&args.url,
&args.token,
kind,
args.port as i32,
args.model.to_owned(),
args.device.to_string(),
ARCH.to_string(),
cpu_info,
cpu_count as i32,
cuda_devices,
RegisterWorkerRequest {
kind,
port: args.port as i32,
name: args.model.to_owned(),
device: args.device.to_string(),
arch: ARCH.to_string(),
cpu_info,
cpu_count: cpu_count as i32,
cuda_devices,
},
)
.await,
}
Expand Down
2 changes: 1 addition & 1 deletion ee/tabby-webserver/examples/update-schema.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fs::write;

use tabby_webserver::create_schema;
use tabby_webserver::public::create_schema;

fn main() {
let schema = create_schema();
Expand Down
56 changes: 56 additions & 0 deletions ee/tabby-webserver/src/handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use std::sync::Arc;

use axum::{
extract::State,
http::Request,
middleware::{from_fn_with_state, Next},
routing, Extension, Router,
};
use hyper::Body;
use juniper_axum::{graphiql, graphql, playground};
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};

use crate::{
hub, repositories,
schema::{create_schema, Schema, ServiceLocator},
service::create_service_locator,
ui,
};

pub async fn attach_webserver(
api: Router,
ui: Router,
logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>,
) -> (Router, Router) {
let ctx = create_service_locator(logger, code).await;
let schema = Arc::new(create_schema());

let api = api
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
.route(
"/graphql",
routing::post(graphql::<Arc<Schema>, Arc<dyn ServiceLocator>>).with_state(ctx.clone()),
)
.route("/graphql", routing::get(playground("/graphql", None)))
.layer(Extension(schema))
.route(
"/hub",
routing::get(hub::ws_handler).with_state(ctx.clone()),
)
.nest("/repositories", repositories::routes(ctx.clone()));

let ui = ui
.route("/graphiql", routing::get(graphiql("/graphql", None)))
.fallback(ui::handler);

(api, ui)
}

async fn distributed_tabby_layer(
State(ws): State<Arc<dyn ServiceLocator>>,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
ws.worker().dispatch_request(request, next).await
}
48 changes: 13 additions & 35 deletions ee/tabby-webserver/src/hub/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tabby_common::api::{
use tokio_tungstenite::connect_async;

use super::websocket::WebSocketTransport;
pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind};
pub use crate::schema::worker::WorkerKind;

#[tarpc::service]
pub trait Hub {
Expand All @@ -29,18 +29,7 @@ pub fn tracing_context() -> tarpc::context::Context {
tarpc::context::current()
}

pub async fn create_client(
addr: &str,
token: &str,
kind: WorkerKind,
port: i32,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
) -> HubClient {
pub async fn create_client(addr: &str, token: &str, request: RegisterWorkerRequest) -> HubClient {
let request = Request::builder()
.uri(format!("ws://{}/hub", addr))
.header("Host", addr)
Expand All @@ -52,17 +41,7 @@ pub async fn create_client(
.header("Content-Type", "application/json")
.header(
&REGISTER_WORKER_HEADER,
serde_json::to_string(&RegisterWorkerRequest {
kind,
port,
name,
device,
arch,
cpu_info,
cpu_count,
cuda_devices,
})
.unwrap(),
serde_json::to_string(&request).unwrap(),
)
.body(())
.unwrap();
Expand Down Expand Up @@ -121,19 +100,18 @@ impl CodeSearch for HubClient {
}

#[derive(Serialize, Deserialize)]
pub(crate) struct RegisterWorkerRequest {
pub(crate) kind: WorkerKind,
pub(crate) port: i32,
pub(crate) name: String,
pub(crate) device: String,
pub(crate) arch: String,
pub(crate) cpu_info: String,
pub(crate) cpu_count: i32,
pub(crate) cuda_devices: Vec<String>,
pub struct RegisterWorkerRequest {
pub kind: WorkerKind,
pub port: i32,
pub name: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}

pub(crate) static REGISTER_WORKER_HEADER: HeaderName =
HeaderName::from_static("x-tabby-register-worker");
pub static REGISTER_WORKER_HEADER: HeaderName = HeaderName::from_static("x-tabby-register-worker");

impl Header for RegisterWorkerRequest {
fn name() -> &'static axum::http::HeaderName {
Expand Down
18 changes: 6 additions & 12 deletions ee/tabby-webserver/src/hub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,20 @@ mod websocket;

use std::{net::SocketAddr, sync::Arc};

use api::{Hub, RegisterWorkerRequest};
use axum::{
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
headers::Header,
response::IntoResponse,
TypedHeader,
};
use hyper::{Body, StatusCode};
use juniper_axum::extract::AuthBearer;
use tabby_common::api::{
code::{CodeSearch, SearchResponse},
event::RawEventLogger,
};
use tabby_common::api::code::SearchResponse;
use tarpc::server::{BaseChannel, Channel};
use tracing::warn;
use websocket::WebSocketTransport;

use self::websocket::WebSocketTransport;
use crate::{
api::{Hub, RegisterWorkerRequest},
schema::{worker::Worker, ServiceLocator},
};
use crate::schema::{worker::Worker, ServiceLocator};

pub(crate) async fn ws_handler(
ws: WebSocketUpgrade,
Expand Down Expand Up @@ -74,13 +68,13 @@ async fn handle_socket(state: Arc<dyn ServiceLocator>, socket: WebSocket, worker
tokio::spawn(server.execute(imp.serve())).await.unwrap()
}

pub struct HubImpl {
struct HubImpl {
ctx: Arc<dyn ServiceLocator>,
worker_addr: String,
}

impl HubImpl {
pub fn new(ctx: Arc<dyn ServiceLocator>, worker_addr: String) -> Self {
fn new(ctx: Arc<dyn ServiceLocator>, worker_addr: String) -> Self {
Self { ctx, worker_addr }
}
}
Expand Down
17 changes: 11 additions & 6 deletions ee/tabby-webserver/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
// used by tabby workers.
pub use hub::api;
// used by examples/update-schema.rs
pub use schema::create_schema;

mod handler;
mod hub;
mod repositories;
mod schema;
mod service;
mod ui;

pub mod public {
pub use super::{
handler::attach_webserver,
/* used by tabby workers (consumer of /hub api) */
hub::api::{create_client, HubClient, RegisterWorkerRequest, WorkerKind},
/* used by examples/update-schema.rs */ schema::create_schema,
};
}

use std::sync::Arc;

use axum::{
Expand All @@ -30,7 +35,7 @@ pub async fn attach_webserver(
code: Arc<dyn CodeSearch>,
) -> (Router, Router) {
let ctx = create_service_locator(logger, code).await;
let schema = Arc::new(create_schema());
let schema = Arc::new(schema::create_schema());

let api = api
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
Expand Down

0 comments on commit 3a33249

Please sign in to comment.