diff --git a/Cargo.lock b/Cargo.lock index dbfe0d9770a9..00740ec36eed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -234,6 +234,7 @@ dependencies = [ "bitflags 1.3.2", "bytes", "futures-util", + "headers", "http", "http-body", "hyper", @@ -1355,6 +1356,31 @@ dependencies = [ "hashbrown 0.14.0", ] +[[package]] +name = "headers" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584" +dependencies = [ + "base64 0.13.1", + "bitflags 1.3.2", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.4.1" @@ -3777,6 +3803,7 @@ dependencies = [ "rusqlite_migration", "rust-embed 8.0.0", "serde", + "serde_json", "tabby-common", "tarpc", "thiserror", diff --git a/crates/tabby/src/worker.rs b/crates/tabby/src/worker.rs index 417d110f3f8b..ad095c8a682d 100644 --- a/crates/tabby/src/worker.rs +++ b/crates/tabby/src/worker.rs @@ -1,10 +1,9 @@ use std::{env::consts::ARCH, net::IpAddr, sync::Arc}; -use anyhow::Result; use axum::{routing, Router}; use clap::Args; -use tabby_webserver::api::{tracing_context, HubClient, WorkerKind}; -use tracing::{info, warn}; +use tabby_webserver::api::{HubClient, WorkerKind}; +use tracing::info; use crate::{ routes::{self, run_app}, @@ -47,9 +46,7 @@ pub struct WorkerArgs { parallelism: u8, } -async fn make_chat_route(context: WorkerContext, args: &WorkerArgs) -> Router { - context.register(WorkerKind::Chat, args).await; - +async fn make_chat_route(args: &WorkerArgs) -> Router { let chat_state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await); @@ -60,8 +57,6 @@ async fn make_chat_route(context: WorkerContext, args: &WorkerArgs) -> Router { } async fn make_completion_route(context: WorkerContext, args: &WorkerArgs) -> Router { - context.register(WorkerKind::Completion, args).await; - let code = Arc::new(context.client.clone()); let logger = Arc::new(context.client); let completion_state = Arc::new( @@ -79,11 +74,11 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) { info!("Starting worker, this might take a few minutes..."); - let context = WorkerContext::new(&args.url).await; + let context = WorkerContext::new(kind.clone(), args).await; let app = match kind { WorkerKind::Completion => make_completion_route(context, args).await, - WorkerKind::Chat => make_chat_route(context, args).await, + WorkerKind::Chat => make_chat_route(args).await, }; run_app(app, None, args.host, args.port).await @@ -94,25 +89,14 @@ struct WorkerContext { } impl WorkerContext { - async fn new(url: &str) -> Self { - Self { - client: tabby_webserver::api::create_client(url).await, - } - } - - async fn register(&self, kind: WorkerKind, args: &WorkerArgs) { - if let Err(err) = self.register_impl(kind, args).await { - warn!("Failed to register worker: {}", err) - } - } - - async fn register_impl(&self, kind: WorkerKind, args: &WorkerArgs) -> Result<()> { + async fn new(kind: WorkerKind, args: &WorkerArgs) -> Self { let (cpu_info, cpu_count) = read_cpu_info(); let cuda_devices = read_cuda_devices().unwrap_or_default(); - let worker = self - .client - .register_worker( - tracing_context(), + + Self { + client: tabby_webserver::api::create_client( + &args.url, + &args.token, kind, args.port as i32, args.model.to_owned(), @@ -121,12 +105,8 @@ impl WorkerContext { cpu_info, cpu_count as i32, cuda_devices, - args.token.clone(), ) - .await??; - - info!("Worker alive at {}", worker.addr); - - Ok(()) + .await, + } } } diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 41ab6101e10a..7e405293e45e 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -9,7 +9,7 @@ homepage.workspace = true anyhow.workspace = true argon2 = "0.5.1" async-trait.workspace = true -axum = { workspace = true, features = ["ws"] } +axum = { workspace = true, features = ["ws", "headers"] } bincode = "1.3.3" chrono = "0.4" futures.workspace = true @@ -26,6 +26,7 @@ rusqlite = { version = "0.30.0", features = ["bundled", "chrono"] } rusqlite_migration = { version = "1.1.0-beta.1", features = ["alpha-async-tokio-rusqlite", "from-directory"] } rust-embed = "8.0.0" serde.workspace = true +serde_json.workspace = true tabby-common = { path = "../../crates/tabby-common" } tarpc = { version = "0.33.0", features = ["serde-transport"] } thiserror.workspace = true diff --git a/ee/tabby-webserver/src/api.rs b/ee/tabby-webserver/src/api.rs index f201a1c81c8e..0bab6d954403 100644 --- a/ee/tabby-webserver/src/api.rs +++ b/ee/tabby-webserver/src/api.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use hyper::Request; use tabby_common::api::{ code::{CodeSearch, CodeSearchError, SearchResponse}, event::RawEventLogger, @@ -6,22 +7,10 @@ use tabby_common::api::{ use tokio_tungstenite::connect_async; pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind}; -use crate::websocket::WebSocketTransport; +use crate::{websocket::WebSocketTransport, RegisterWorkerRequest, REGISTER_WORKER_HEADER}; #[tarpc::service] pub trait Hub { - async fn register_worker( - kind: WorkerKind, - port: i32, - name: String, - device: String, - arch: String, - cpu_info: String, - cpu_count: i32, - cuda_devices: Vec, - token: String, - ) -> Result; - async fn log_event(content: String); async fn search(q: String, limit: usize, offset: usize) -> SearchResponse; @@ -38,9 +27,45 @@ pub fn tracing_context() -> tarpc::context::Context { tarpc::context::current() } -pub async fn create_client(addr: &str) -> HubClient { - let addr = format!("ws://{}/hub", addr); - let (socket, _) = connect_async(&addr).await.unwrap(); +pub async fn create_client( + addr: &str, + token: &str, + kind: WorkerKind, + port: i32, + name: String, + device: String, + arch: String, + cpu_info: String, + cpu_count: i32, + cuda_devices: Vec, +) -> HubClient { + let request = Request::builder() + .uri(format!("ws://{}/hub", addr)) + .header("Host", addr) + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", "unused") + .header("Authorization", format!("Bearer {}", token)) + .header("Content-Type", "application/json") + .header( + ®ISTER_WORKER_HEADER, + serde_json::to_string(&RegisterWorkerRequest { + kind, + port, + name, + device, + arch, + cpu_info, + cpu_count, + cuda_devices, + }) + .unwrap(), + ) + .body(()) + .unwrap(); + + let (socket, _) = connect_async(request).await.unwrap(); HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn() } diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index ec2fedcf95dc..d74adf7f53e5 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -3,12 +3,12 @@ pub mod api; mod schema; use api::Hub; pub use schema::create_schema; +use serde::{Deserialize, Serialize}; use tabby_common::api::{ code::{CodeSearch, SearchResponse}, event::RawEventLogger, }; -use tokio::sync::Mutex; -use tracing::{error, warn}; +use tracing::warn; use websocket::WebSocketTransport; mod repositories; @@ -20,15 +20,16 @@ use std::{net::SocketAddr, sync::Arc}; use axum::{ extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade}, - http::Request, + headers::Header, + http::{HeaderName, Request}, middleware::{from_fn_with_state, Next}, response::IntoResponse, - routing, Extension, Router, + routing, Extension, Router, TypedHeader, }; -use hyper::Body; -use juniper_axum::{graphiql, graphql, playground}; +use hyper::{Body, StatusCode}; +use juniper_axum::{extract::AuthBearer, graphiql, graphql, playground}; use schema::{ - worker::{RegisterWorkerError, Worker, WorkerKind}, + worker::{Worker, WorkerKind}, Schema, ServiceLocator, }; use service::create_service_locator; @@ -69,35 +70,103 @@ async fn distributed_tabby_layer( ws.worker().dispatch_request(request, next).await } +#[derive(Serialize, Deserialize)] +struct RegisterWorkerRequest { + kind: WorkerKind, + port: i32, + name: String, + device: String, + arch: String, + cpu_info: String, + cpu_count: i32, + cuda_devices: Vec, +} + +pub static REGISTER_WORKER_HEADER: HeaderName = HeaderName::from_static("x-tabby-register-worker"); + +impl Header for RegisterWorkerRequest { + fn name() -> &'static axum::http::HeaderName { + ®ISTER_WORKER_HEADER + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let mut x: Vec<_> = values + .map(|x| serde_json::from_slice(x.as_bytes())) + .collect(); + if let Some(x) = x.pop() { + x.map_err(|_| axum::headers::Error::invalid()) + } else { + Err(axum::headers::Error::invalid()) + } + } + + fn encode>(&self, _values: &mut E) { + todo!() + } +} + async fn ws_handler( ws: WebSocketUpgrade, State(state): State>, + AuthBearer(token): AuthBearer, ConnectInfo(addr): ConnectInfo, + TypedHeader(request): TypedHeader, ) -> impl IntoResponse { - ws.on_upgrade(move |socket| handle_socket(state, socket, addr)) + let unauthorized = axum::response::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(Body::empty()) + .unwrap() + .into_response(); + + let Some(token) = token else { + return unauthorized; + }; + + let Ok(registeration_token) = state.worker().read_registration_token().await else { + return unauthorized; + }; + + if token != registeration_token { + return unauthorized; + } + + let addr = format!("http://{}:{}", addr.ip(), request.port); + + let worker = Worker { + name: request.name, + kind: request.kind, + addr, + device: request.device, + arch: request.arch, + cpu_info: request.cpu_info, + cpu_count: request.cpu_count, + cuda_devices: request.cuda_devices, + }; + + ws.on_upgrade(move |socket| handle_socket(state, socket, worker)) + .into_response() } -async fn handle_socket(state: Arc, socket: WebSocket, addr: SocketAddr) { +async fn handle_socket(state: Arc, socket: WebSocket, worker: Worker) { let transport = WebSocketTransport::from(socket); let server = BaseChannel::with_defaults(transport); - let imp = Arc::new(HubImpl::new(state.clone(), addr)); + let imp = Arc::new(HubImpl::new(state.clone(), worker.addr.clone())); + state.worker().register_worker(worker).await.unwrap(); tokio::spawn(server.execute(imp.serve())).await.unwrap() } pub struct HubImpl { ctx: Arc, - conn: SocketAddr, - - worker_addr: Arc>, + worker_addr: String, } impl HubImpl { - pub fn new(ctx: Arc, conn: SocketAddr) -> Self { - Self { - ctx, - conn, - worker_addr: Arc::new(Mutex::new("".to_owned())), - } + pub fn new(ctx: Arc, worker_addr: String) -> Self { + Self { ctx, worker_addr } } } @@ -107,70 +176,13 @@ impl Drop for HubImpl { let worker_addr = self.worker_addr.clone(); tokio::spawn(async move { - let worker_addr = worker_addr.lock().await; - if !worker_addr.is_empty() { - ctx.worker().unregister_worker(worker_addr.as_str()).await; - } + ctx.worker().unregister_worker(worker_addr.as_str()).await; }); } } #[tarpc::server] impl Hub for Arc { - async fn register_worker( - self, - _context: tarpc::context::Context, - kind: WorkerKind, - port: i32, - name: String, - device: String, - arch: String, - cpu_info: String, - cpu_count: i32, - cuda_devices: Vec, - token: String, - ) -> Result { - if token.is_empty() { - return Err(RegisterWorkerError::InvalidToken( - "Empty worker token".to_string(), - )); - } - let server_token = match self.ctx.worker().read_registration_token().await { - Ok(t) => t, - Err(err) => { - error!("fetch server token: {}", err.to_string()); - return Err(RegisterWorkerError::InvalidToken( - "Failed to fetch server token".to_string(), - )); - } - }; - if server_token != token { - return Err(RegisterWorkerError::InvalidToken( - "Token mismatch".to_string(), - )); - } - - let mut worker_addr = self.worker_addr.lock().await; - if !worker_addr.is_empty() { - return Err(RegisterWorkerError::RegisterWorkerOnce); - } - - let addr = format!("http://{}:{}", self.conn.ip(), port); - *worker_addr = addr.clone(); - - let worker = Worker { - name, - kind, - addr, - device, - arch, - cpu_info, - cpu_count, - cuda_devices, - }; - self.ctx.worker().register_worker(worker).await - } - async fn log_event(self, _context: tarpc::context::Context, content: String) { self.ctx.logger().log(content) }